Implemented Policies with Contextual Covariance

This commit is contained in:
Dominik Moritz Roth 2022-07-13 19:38:20 +02:00
parent 41e4170b2f
commit fae19509bc
3 changed files with 9 additions and 28 deletions

View File

@ -1 +1,2 @@
#TODO: License or such #TODO: License or such
from .distributions import *

View File

@ -2,10 +2,10 @@ import torch as th
from stable_baselines3.common.distributions import Distribution as SB3_Distribution from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution from ..distributions import UniversalGaussianDistribution, AnyDistribution
def get_mean_and_chol(p, expand=False): def get_mean_and_chol(p: AnyDistribution, expand=False):
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):
if expand: if expand:
return p.mean, th.diag_embed(p.stddev) return p.mean, th.diag_embed(p.stddev)
@ -19,7 +19,7 @@ def get_mean_and_chol(p, expand=False):
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
def get_mean_and_sqrt(p): def get_mean_and_sqrt(p: UniversalGaussianDistribution):
raise Exception('Not yet implemented...') raise Exception('Not yet implemented...')
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):
return p.mean, p.stddev return p.mean, p.stddev
@ -31,7 +31,7 @@ def get_mean_and_sqrt(p):
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
def get_cov(p): def get_cov(p: AnyDistribution):
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):
return th.diag_embed(p.variance) return th.diag_embed(p.variance)
elif isinstance(p, th.distributions.MultivariateNormal): elif isinstance(p, th.distributions.MultivariateNormal):
@ -42,7 +42,7 @@ def get_cov(p):
raise Exception('Dist-Type not implemented') raise Exception('Dist-Type not implemented')
def has_diag_cov(p, numerical_check=True): def has_diag_cov(p: AnyDistribution, numerical_check=True):
if isinstance(p, SB3_Distribution): if isinstance(p, SB3_Distribution):
return has_diag_cov(p.distribution, numerical_check=numerical_check) return has_diag_cov(p.distribution, numerical_check=numerical_check)
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):
@ -54,18 +54,18 @@ def has_diag_cov(p, numerical_check=True):
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov))) return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
def is_contextual(p): def is_contextual(p: AnyDistribution):
# TODO: Implement for UniversalGaussianDist # TODO: Implement for UniversalGaussianDist
return False return False
def get_diag_cov_vec(p, check_diag=True, numerical_check=True): def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
if check_diag and not has_diag_cov(p): if check_diag and not has_diag_cov(p):
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal') raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
return th.diagonal(get_cov(p), dim1=-2, dim2=-1) return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
def new_dist_like(orig_p, mean, chol): def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
if isinstance(orig_p, UniversalGaussianDistribution): if isinstance(orig_p, UniversalGaussianDistribution):
return orig_p.new_list_like_me(mean, chol) return orig_p.new_list_like_me(mean, chol)
elif isinstance(orig_p, th.distributions.Normal): elif isinstance(orig_p, th.distributions.Normal):

View File

@ -1,20 +0,0 @@
import torch as th
from torch import nn
class FakeModule(nn.Module):
"""
A torch.nn Module, that drops the input and returns a tensor given at initialization.
Gradients can pass through this Module and affect the given tensor.
"""
# In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule
def __init__(self, tensor):
super().__init__()
self.tensor = tensor
def forward(self, x):
return self.tensor
def string(self):
return '<FakeModule: '+str(self.tensor)+'>'