diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 879bad2..90534cd 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -79,13 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng for cs in allowedCovStrength: if ps.value > cs.value: continue - for ept in allowedEPTs: - if cs == Strength.FULL: - for pt in allowedPTs: - if pt != ParametrizationType.NONE: - yield (ps, cs, ept, pt) - else: - yield (ps, cs, ept, ParametrizationType.NONE) + if cs == Strength.NONE: + yield (ps, cs, EnforcePositiveType.NONE, ParametrizationType.NONE) + else: + for ept in allowedEPTs: + if cs == Strength.FULL: + for pt in allowedPTs: + if pt != ParametrizationType.NONE: + yield (ps, cs, ept, pt) + else: + yield (ps, cs, ept, ParametrizationType.NONE) def make_proba_distribution( @@ -147,6 +150,10 @@ class UniversalGaussianDistribution(SB3_Distribution): assert (parameterization_type != ParametrizationType.NONE) == ( cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' + if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE: + raise Exception( + 'You need to specify an enforce_positive_type for spherical_cholesky') + def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor): p = self.distribution if isinstance(p, Independent): @@ -223,9 +230,9 @@ class UniversalGaussianDistribution(SB3_Distribution): def mode(self) -> th.Tensor: return self.distribution.mean - def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: + def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor: # Update the proba distribution - self.proba_distribution(mean_actions, log_std) + self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi) return self.get_actions(deterministic=deterministic) def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: