Allow cloning UniversalGaussianDistribution (new_dist_like)
This commit is contained in:
parent
c08ea1cb91
commit
4c4b12ee0e
@ -12,6 +12,7 @@ from stable_baselines3.common.distributions import Distribution as SB3_Distribut
|
||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||
|
||||
from ..misc.fakeModule import FakeModule
|
||||
from ..misc.distTools import new_dist_like
|
||||
|
||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||
# TODO: Support Squashed Dists (tanh)
|
||||
@ -96,6 +97,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
self.distribution = None
|
||||
|
||||
def new_dist_like_me(self, mean, pseudo_chol):
    """Return a new UniversalGaussianDistribution configured like this one.

    The inner distribution of the result is produced by calling
    ``new_dist_like`` with this object's current ``distribution``, ``mean``
    and ``pseudo_chol``; the wrapper itself is freshly constructed from
    ``self.action_dim`` and then receives a copy of this object's
    parametrization settings.

    :param mean: forwarded unchanged to ``new_dist_like``
        (presumably the mean of the new distribution -- confirm).
    :param pseudo_chol: forwarded unchanged to ``new_dist_like``
        (presumably a Cholesky-like factor -- confirm).
    :return: the newly built UniversalGaussianDistribution.
    """
    # Build the inner distribution first, before constructing the wrapper.
    # NOTE(review): if self.distribution is still None, whatever
    # new_dist_like returns for None is stored as-is -- confirm intended.
    inner_dist = new_dist_like(self.distribution, mean, pseudo_chol)

    clone = UniversalGaussianDistribution(self.action_dim)
    # Carry over the settings that are not passed to the constructor here.
    for attr_name in (
        "par_strength",
        "cov_strength",
        "par_type",
        "enforce_positive_type",
        "prob_squashing_type",
    ):
        setattr(clone, attr_name, getattr(self, attr_name))
    clone.distribution = inner_dist

    return clone
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
|
@ -2,6 +2,8 @@ import torch as th
|
||||
|
||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||
|
||||
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
|
||||
|
||||
|
||||
def get_mean_and_chol(p, expand=False):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
@ -64,7 +66,9 @@ def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
||||
|
||||
|
||||
def new_dist_like(orig_p, mean, chol):
    """Build a distribution of the same kind as ``orig_p`` from ``mean`` and ``chol``.

    Dispatches on the type of ``orig_p``:

    * ``UniversalGaussianDistribution``: delegates to its
      ``new_dist_like_me`` method.
    * ``torch.distributions.Normal``: returns a new ``Normal(mean, chol)``;
      when ``chol`` does not have the same shape as ``orig_p.stddev`` its
      diagonal (over dims 1 and 2) is used instead.

    :param orig_p: distribution whose type selects what is constructed.
    :param mean: first argument of the new distribution.
    :param chol: scale argument of the new distribution
        (presumably a batch of Cholesky factors -- confirm).
    :return: the newly constructed distribution.
    """
    # NOTE(review): this module imports UniversalGaussianDistribution at
    # module level while the distributions module imports new_dist_like from
    # here -- confirm this does not produce a circular-import failure.
    if isinstance(orig_p, UniversalGaussianDistribution):
        # The method is named new_dist_like_me; the previous spelling
        # ("new_list_like_me") raised AttributeError on every call.
        return orig_p.new_dist_like_me(mean, chol)
    elif isinstance(orig_p, th.distributions.Normal):
        if orig_p.stddev.shape != chol.shape:
            # Shapes disagree: reduce chol to its diagonal.
            # assumes chol is (batch, dim, dim) -- TODO confirm
            chol = th.diagonal(chol, dim1=1, dim2=2)
        return th.distributions.Normal(mean, chol)
|
||||
|
Loading…
Reference in New Issue
Block a user