From 4c4b12ee0e3b86b4ce05ac13278fbf0c8ef20dc5 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 9 Jul 2022 14:45:35 +0200
Subject: [PATCH] Allow cloning UniversalGaussianDistribution (new_dist_like)

---
 .../distributions/distributions.py     | 21 +++++++++++++++++++++
 metastable_baselines/misc/distTools.py |  8 +++++++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index 8e54218..45f7eec 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -12,6 +12,7 @@ from stable_baselines3.common.distributions import Distribution as SB3_Distribut
 from stable_baselines3.common.distributions import DiagGaussianDistribution
 
 from ..misc.fakeModule import FakeModule
+from ..misc.distTools import new_dist_like
 
 # TODO: Integrate and Test what I currently have before adding more complexity
 # TODO: Support Squashed Dists (tanh)
@@ -96,6 +97,26 @@ class UniversalGaussianDistribution(SB3_Distribution):
 
         self.distribution = None
 
+    def new_dist_like_me(self, mean, pseudo_chol):
+        """
+        Create a new UniversalGaussianDistribution configured like this one.
+
+        The underlying distribution of the copy is built by `new_dist_like`
+        from `mean` and `pseudo_chol`, using `self.distribution` as the
+        template.
+        """
+        p = self.distribution
+        new_p = new_dist_like(p, mean, pseudo_chol)
+        new = UniversalGaussianDistribution(self.action_dim)
+        new.par_strength = self.par_strength
+        new.cov_strength = self.cov_strength
+        new.par_type = self.par_type
+        new.enforce_positive_type = self.enforce_positive_type
+        new.prob_squashing_type = self.prob_squashing_type
+        new.distribution = new_p
+
+        return new
+
     def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
         """
         Create the layers and parameter that represent the distribution:
diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py
index 0e7d197..2a78cf8 100644
--- a/metastable_baselines/misc/distTools.py
+++ b/metastable_baselines/misc/distTools.py
@@ -64,7 +64,13 @@ def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
 
 
 def new_dist_like(orig_p, mean, chol):
-    if isinstance(orig_p, th.distributions.Normal):
+    # Imported lazily: distributions.py imports new_dist_like from this module
+    # at load time, so a module-level import here would be circular.
+    from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
+
+    if isinstance(orig_p, UniversalGaussianDistribution):
+        return orig_p.new_dist_like_me(mean, chol)
+    elif isinstance(orig_p, th.distributions.Normal):
         if orig_p.stddev.shape != chol.shape:
             chol = th.diagonal(chol, dim1=1, dim2=2)
         return th.distributions.Normal(mean, chol)