Allow cloning UniversalGaussianDistribution (new_dist_like)
This commit is contained in:
parent
c08ea1cb91
commit
4c4b12ee0e
@ -12,6 +12,7 @@ from stable_baselines3.common.distributions import Distribution as SB3_Distribut
|
|||||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||||
|
|
||||||
from ..misc.fakeModule import FakeModule
|
from ..misc.fakeModule import FakeModule
|
||||||
|
from ..misc.distTools import new_dist_like
|
||||||
|
|
||||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||||
# TODO: Support Squashed Dists (tanh)
|
# TODO: Support Squashed Dists (tanh)
|
||||||
@ -96,6 +97,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
self.distribution = None
|
self.distribution = None
|
||||||
|
|
||||||
|
def new_dist_like_me(self, mean, pseudo_chol):
    """Return a new UniversalGaussianDistribution configured like this one.

    The wrapped ``distribution`` of the result is whatever ``new_dist_like``
    returns when given this instance's ``distribution`` together with
    ``mean`` and ``pseudo_chol``. The parameterization settings
    (``par_strength``, ``cov_strength``, ``par_type``,
    ``enforce_positive_type``, ``prob_squashing_type``) are copied from
    ``self`` onto the new instance.

    :param mean: forwarded unchanged to ``new_dist_like`` as its second argument.
    :param pseudo_chol: forwarded unchanged to ``new_dist_like`` as its third
        argument (presumably a Cholesky-like factor -- confirm with callers).
    :return: a freshly constructed ``UniversalGaussianDistribution``.
    """
    # Build the inner distribution first, so a failure here happens before
    # any new wrapper object is constructed.
    inner = new_dist_like(self.distribution, mean, pseudo_chol)

    clone = UniversalGaussianDistribution(self.action_dim)
    # Carry over the settings in a fixed order.
    for attr in (
        "par_strength",
        "cov_strength",
        "par_type",
        "enforce_positive_type",
        "prob_squashing_type",
    ):
        setattr(clone, attr, getattr(self, attr))
    clone.distribution = inner

    return clone
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
|
||||||
"""
|
"""
|
||||||
Create the layers and parameter that represent the distribution:
|
Create the layers and parameter that represent the distribution:
|
||||||
|
@ -2,6 +2,8 @@ import torch as th
|
|||||||
|
|
||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||||
|
|
||||||
|
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
|
||||||
|
|
||||||
|
|
||||||
def get_mean_and_chol(p, expand=False):
|
def get_mean_and_chol(p, expand=False):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
@ -64,7 +66,9 @@ def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
|||||||
|
|
||||||
|
|
||||||
def new_dist_like(orig_p, mean, chol):
|
def new_dist_like(orig_p, mean, chol):
|
||||||
if isinstance(orig_p, th.distributions.Normal):
|
if isinstance(orig_p, UniversalGaussianDistribution):
|
||||||
|
return orig_p.new_list_like_me(mean, chol)
|
||||||
|
elif isinstance(orig_p, th.distributions.Normal):
|
||||||
if orig_p.stddev.shape != chol.shape:
|
if orig_p.stddev.shape != chol.shape:
|
||||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||||
return th.distributions.Normal(mean, chol)
|
return th.distributions.Normal(mean, chol)
|
||||||
|
Loading…
Reference in New Issue
Block a user