diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index b400c64..cc515d6 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -325,7 +325,14 @@ class UniversalGaussianDistribution(SB3_Distribution): def _sample_hybrid(self) -> th.Tensor: f = self.hybrid_rex_fac - actions = self._sample_normal()*f + self._sample_sde()*(1-f) + + rex_sample = self.distribution.rsample() + + noise = self.get_noise(self._latent_sde) + sde_sample = self.distribution.mean + noise + + actions = rex_sample*f + sde_sample*(1-f) + self.gaussian_actions = actions return self.prob_squashing_type.apply(actions) def mode(self) -> th.Tensor: