From 2c1689fbbc2cd5d1a03a93cece18aa8d17b77fe1 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 6 Aug 2022 14:36:35 +0200 Subject: [PATCH] Fixed smol bugs when instanciating based on string-names of enums --- metastable_baselines/distributions/distributions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 9752560..bff2443 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -116,10 +116,12 @@ def make_proba_distribution( if dist_kwargs is None: dist_kwargs = {} + dist_kwargs['use_sde'] = use_sde + if isinstance(action_space, gym.spaces.Box): assert len( action_space.shape) == 1, "Error: the action space must be a vector" - return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs) + return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, gym.spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, gym.spaces.MultiDiscrete): @@ -151,7 +153,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.enforce_positive_type = cast_to_enum( enforce_positive_type, EnforcePositiveType) self.prob_squashing_type = cast_to_enum( - prob_squashing_type, EnforcePositiveType) + prob_squashing_type, ProbSquashingType) self.epsilon = epsilon @@ -161,8 +163,8 @@ class UniversalGaussianDistribution(SB3_Distribution): if use_sde: raise Exception('SDE is not yet implemented') - assert (parameterization_type != ParametrizationType.NONE) == ( - cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' + assert (self.par_type != ParametrizationType.NONE) == ( + self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE: raise Exception(