Implemented Policies with Contextual Covariance
This commit is contained in:
parent
41e4170b2f
commit
fae19509bc
@ -1 +1,2 @@
|
||||
#TODO: License or such
|
||||
from .distributions import *
|
||||
|
@ -2,10 +2,10 @@ import torch as th
|
||||
|
||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||
|
||||
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
|
||||
from ..distributions import UniversalGaussianDistribution, AnyDistribution
|
||||
|
||||
|
||||
def get_mean_and_chol(p, expand=False):
|
||||
def get_mean_and_chol(p: AnyDistribution, expand=False):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
if expand:
|
||||
return p.mean, th.diag_embed(p.stddev)
|
||||
@ -19,7 +19,7 @@ def get_mean_and_chol(p, expand=False):
|
||||
raise Exception('Dist-Type not implemented')
|
||||
|
||||
|
||||
def get_mean_and_sqrt(p):
|
||||
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
||||
raise Exception('Not yet implemented...')
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
return p.mean, p.stddev
|
||||
@ -31,7 +31,7 @@ def get_mean_and_sqrt(p):
|
||||
raise Exception('Dist-Type not implemented')
|
||||
|
||||
|
||||
def get_cov(p):
|
||||
def get_cov(p: AnyDistribution):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
return th.diag_embed(p.variance)
|
||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||
@ -42,7 +42,7 @@ def get_cov(p):
|
||||
raise Exception('Dist-Type not implemented')
|
||||
|
||||
|
||||
def has_diag_cov(p, numerical_check=True):
|
||||
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
||||
if isinstance(p, SB3_Distribution):
|
||||
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
@ -54,18 +54,18 @@ def has_diag_cov(p, numerical_check=True):
|
||||
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
||||
|
||||
|
||||
def is_contextual(p):
|
||||
def is_contextual(p: AnyDistribution):
|
||||
# TODO: Implement for UniveralGaussianDist
|
||||
return False
|
||||
|
||||
|
||||
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
||||
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
|
||||
if check_diag and not has_diag_cov(p):
|
||||
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
||||
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
|
||||
|
||||
|
||||
def new_dist_like(orig_p, mean, chol):
|
||||
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||
if isinstance(orig_p, UniversalGaussianDistribution):
|
||||
return orig_p.new_list_like_me(mean, chol)
|
||||
elif isinstance(orig_p, th.distributions.Normal):
|
||||
|
@ -1,20 +0,0 @@
|
||||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
|
||||
class FakeModule(nn.Module):
|
||||
"""
|
||||
A torch.nn Module, that drops the input and returns a tensor given at initialization.
|
||||
Gradients can pass through this Module and affect the given tensor.
|
||||
"""
|
||||
# In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule
|
||||
|
||||
def __init__(self, tensor):
|
||||
super().__init__()
|
||||
self.tensor = tensor
|
||||
|
||||
def forward(self, x):
|
||||
return self.tensor
|
||||
|
||||
def string(self):
|
||||
return '<FakeModule: '+str(self.tensor)+'>'
|
Loading…
Reference in New Issue
Block a user