diff --git a/metastable_baselines/distributions/__init__.py b/metastable_baselines/distributions/__init__.py index 7f0891a..3c776c4 100644 --- a/metastable_baselines/distributions/__init__.py +++ b/metastable_baselines/distributions/__init__.py @@ -1 +1,2 @@ #TODO: License or such +from .distributions import * diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 2a78cf8..6d3e5ba 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -2,10 +2,10 @@ import torch as th from stable_baselines3.common.distributions import Distribution as SB3_Distribution -from metastable_baselines.distributions.distributions import UniversalGaussianDistribution +from ..distributions import UniversalGaussianDistribution, AnyDistribution -def get_mean_and_chol(p, expand=False): +def get_mean_and_chol(p: AnyDistribution, expand=False): if isinstance(p, th.distributions.Normal): if expand: return p.mean, th.diag_embed(p.stddev) @@ -19,7 +19,7 @@ def get_mean_and_chol(p, expand=False): raise Exception('Dist-Type not implemented') -def get_mean_and_sqrt(p): +def get_mean_and_sqrt(p: UniversalGaussianDistribution): raise Exception('Not yet implemented...') if isinstance(p, th.distributions.Normal): return p.mean, p.stddev @@ -31,7 +31,7 @@ def get_mean_and_sqrt(p): raise Exception('Dist-Type not implemented') -def get_cov(p): +def get_cov(p: AnyDistribution): if isinstance(p, th.distributions.Normal): return th.diag_embed(p.variance) elif isinstance(p, th.distributions.MultivariateNormal): @@ -42,7 +42,7 @@ def get_cov(p): raise Exception('Dist-Type not implemented') -def has_diag_cov(p, numerical_check=True): +def has_diag_cov(p: AnyDistribution, numerical_check=True): if isinstance(p, SB3_Distribution): return has_diag_cov(p.distribution, numerical_check=numerical_check) if isinstance(p, th.distributions.Normal): @@ -54,18 +54,18 @@ def has_diag_cov(p, numerical_check=True): return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov))) -def is_contextual(p): +def is_contextual(p: AnyDistribution): # TODO: Implement for UniveralGaussianDist return False -def get_diag_cov_vec(p, check_diag=True, numerical_check=True): +def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True): if check_diag and not has_diag_cov(p): raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal') return th.diagonal(get_cov(p), dim1=-2, dim2=-1) -def new_dist_like(orig_p, mean, chol): +def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): if isinstance(orig_p, UniversalGaussianDistribution): return orig_p.new_list_like_me(mean, chol) elif isinstance(orig_p, th.distributions.Normal): diff --git a/metastable_baselines/misc/fakeModule.py b/metastable_baselines/misc/fakeModule.py deleted file mode 100644 index 72f2434..0000000 --- a/metastable_baselines/misc/fakeModule.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch as th -from torch import nn - - -class FakeModule(nn.Module): - """ - A torch.nn Module, that drops the input and returns a tensor given at initialization. - Gradients can pass through this Module and affect the given tensor. - """ - # In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule - - def __init__(self, tensor): - super().__init__() - self.tensor = tensor - - def forward(self, x): - return self.tensor - - def string(self): - return ''