Implemented Policies with Contextual Covariance
This commit is contained in:
parent
41e4170b2f
commit
fae19509bc
@ -1 +1,2 @@
|
|||||||
#TODO: License or such
|
#TODO: License or such
|
||||||
|
from .distributions import *
|
||||||
|
@ -2,10 +2,10 @@ import torch as th
|
|||||||
|
|
||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||||
|
|
||||||
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
|
from ..distributions import UniversalGaussianDistribution, AnyDistribution
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_chol(p, expand=False):
|
def get_mean_and_chol(p: AnyDistribution, expand=False):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
if expand:
|
if expand:
|
||||||
return p.mean, th.diag_embed(p.stddev)
|
return p.mean, th.diag_embed(p.stddev)
|
||||||
@ -19,7 +19,7 @@ def get_mean_and_chol(p, expand=False):
|
|||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_sqrt(p):
|
def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
||||||
raise Exception('Not yet implemented...')
|
raise Exception('Not yet implemented...')
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
return p.mean, p.stddev
|
return p.mean, p.stddev
|
||||||
@ -31,7 +31,7 @@ def get_mean_and_sqrt(p):
|
|||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
def get_cov(p):
|
def get_cov(p: AnyDistribution):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
return th.diag_embed(p.variance)
|
return th.diag_embed(p.variance)
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
@ -42,7 +42,7 @@ def get_cov(p):
|
|||||||
raise Exception('Dist-Type not implemented')
|
raise Exception('Dist-Type not implemented')
|
||||||
|
|
||||||
|
|
||||||
def has_diag_cov(p, numerical_check=True):
|
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
||||||
if isinstance(p, SB3_Distribution):
|
if isinstance(p, SB3_Distribution):
|
||||||
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
@ -54,18 +54,18 @@ def has_diag_cov(p, numerical_check=True):
|
|||||||
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
||||||
|
|
||||||
|
|
||||||
def is_contextual(p):
|
def is_contextual(p: AnyDistribution):
|
||||||
# TODO: Implement for UniveralGaussianDist
|
# TODO: Implement for UniveralGaussianDist
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
|
||||||
if check_diag and not has_diag_cov(p):
|
if check_diag and not has_diag_cov(p):
|
||||||
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
||||||
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
|
return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
|
||||||
|
|
||||||
|
|
||||||
def new_dist_like(orig_p, mean, chol):
|
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||||
if isinstance(orig_p, UniversalGaussianDistribution):
|
if isinstance(orig_p, UniversalGaussianDistribution):
|
||||||
return orig_p.new_list_like_me(mean, chol)
|
return orig_p.new_list_like_me(mean, chol)
|
||||||
elif isinstance(orig_p, th.distributions.Normal):
|
elif isinstance(orig_p, th.distributions.Normal):
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
import torch as th
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class FakeModule(nn.Module):
|
|
||||||
"""
|
|
||||||
A torch.nn Module, that drops the input and returns a tensor given at initialization.
|
|
||||||
Gradients can pass through this Module and affect the given tensor.
|
|
||||||
"""
|
|
||||||
# In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule
|
|
||||||
|
|
||||||
def __init__(self, tensor):
|
|
||||||
super().__init__()
|
|
||||||
self.tensor = tensor
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.tensor
|
|
||||||
|
|
||||||
def string(self):
|
|
||||||
return '<FakeModule: '+str(self.tensor)+'>'
|
|
Loading…
Reference in New Issue
Block a user