From bb1f9ecf2b058184dc81d2380cf591e42906c0b1 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 14 Aug 2022 18:42:19 +0200 Subject: [PATCH] Fixed UniversalGaussianDistribution lost SDE when cloning --- metastable_baselines/distributions/distributions.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index afa8d86..3d477d9 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -143,7 +143,7 @@ class UniversalGaussianDistribution(SB3_Distribution): :param action_dim: Dimension of the action space. """ - def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6, sde_learn_features=False, full_sde=None): + def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6, sde_learn_features=False): super(UniversalGaussianDistribution, self).__init__() self.action_dim = action_dim self.par_strength = cast_to_enum(neural_strength, Strength) @@ -163,8 +163,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.use_sde = use_sde self.learn_features = sde_learn_features - if full_sde != None: - print('[!] Argument full_sde is only provided to remain compatible with vanilla SB3 PPO. It does not serve any function!') + print('sde', self.use_sde) assert (self.par_type != ParametrizationType.NONE) == ( self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' @@ -181,8 +180,8 @@ class UniversalGaussianDistribution(SB3_Distribution): np = Independent(Normal(mean, chol), 1) elif isinstance(p, MultivariateNormal): np = MultivariateNormal(mean, scale_tril=chol) - new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength, - parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type) + new = UniversalGaussianDistribution(self.action_dim, use_sde=self.use_sde, neural_strength=self.par_strength, cov_strength=self.cov_strength, + parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type, epsilon=self.epsilon, sde_learn_features=self.learn_features) new.distribution = np return new