From 09159774d9599c98b0d950d3a5471c280e3a8764 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 13 Mar 2023 20:44:09 +0100 Subject: [PATCH] Fix backward pass error for hybrid-method --- metastable_baselines/distributions/distributions.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index b400c64..cc515d6 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -325,7 +325,14 @@ class UniversalGaussianDistribution(SB3_Distribution): def _sample_hybrid(self) -> th.Tensor: f = self.hybrid_rex_fac - actions = self._sample_normal()*f + self._sample_sde()*(1-f) + + rex_sample = self.distribution.rsample() + + noise = self.get_noise(self._latent_sde) + sde_sample = self.distribution.mean + noise + + actions = rex_sample*f + sde_sample*(1-f) + self.gaussian_actions = actions return self.prob_squashing_type.apply(actions) def mode(self) -> th.Tensor: