diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 22dc522..b400c64 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -159,9 +159,8 @@ class UniversalGaussianDistribution(SB3_Distribution): self.learn_features = sde_learn_features self.sde_latent_softmax = sde_latent_softmax - if self.use_hybrid and not self.use_sde: - print('[!] use_hybrid forces use_sde to be true') - self.use_sde = True + if self.use_hybrid: + assert self.use_sde, 'use_sde has to be set to use use_hybrid' assert (self.par_type != ParametrizationType.NONE) == ( self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' @@ -318,13 +317,14 @@ class UniversalGaussianDistribution(SB3_Distribution): return self.prob_squashing_type.apply(sample) def _sample_sde(self) -> th.Tensor: + # More Reparametrization trick to pass gradients noise = self.get_noise(self._latent_sde) actions = self.distribution.mean + noise self.gaussian_actions = actions return self.prob_squashing_type.apply(actions) def _sample_hybrid(self) -> th.Tensor: - f = self.hybrid_rex_factor + f = self.hybrid_rex_fac actions = self._sample_normal()*f + self._sample_sde()*(1-f) return self.prob_squashing_type.apply(actions)