Allow cloning UniversalGaussianDistribution (new_dist_like)

This commit is contained in:
Dominik Moritz Roth 2022-07-09 14:45:35 +02:00
parent c08ea1cb91
commit 4c4b12ee0e
2 changed files with 19 additions and 1 deletions

View File

@ -12,6 +12,7 @@ from stable_baselines3.common.distributions import Distribution as SB3_Distribut
from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.fakeModule import FakeModule
from ..misc.distTools import new_dist_like
# TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh)
@ -96,6 +97,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.distribution = None
def new_dist_like_me(self, mean, pseudo_chol):
p = self.distribution
np = new_dist_like(p, mean, pseudo_chol)
new = UniversalGaussianDistribution(self.action_dim)
new.par_strength = self.par_strength
new.cov_strength = self.cov_strength
new.par_type = self.par_type
new.enforce_positive_type = self.enforce_positive_type
new.prob_squashing_type = self.prob_squashing_type
new.distribution = np
return new
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
"""
Create the layers and parameter that represent the distribution:

View File

@ -2,6 +2,8 @@ import torch as th
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
def get_mean_and_chol(p, expand=False):
if isinstance(p, th.distributions.Normal):
@ -64,7 +66,9 @@ def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
def new_dist_like(orig_p, mean, chol):
if isinstance(orig_p, th.distributions.Normal):
if isinstance(orig_p, UniversalGaussianDistribution):
return orig_p.new_list_like_me(mean, chol)
elif isinstance(orig_p, th.distributions.Normal):
if orig_p.stddev.shape != chol.shape:
chol = th.diagonal(chol, dim1=1, dim2=2)
return th.distributions.Normal(mean, chol)