Fix backward pass error for hybrid-method

This commit is contained in:
Dominik Moritz Roth 2023-03-13 20:44:09 +01:00
parent 75feefbe5a
commit 09159774d9

View File

@ -325,7 +325,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
def _sample_hybrid(self) -> th.Tensor:
f = self.hybrid_rex_fac
actions = self._sample_normal()*f + self._sample_sde()*(1-f)
rex_sample = self.distribution.rsample()
noise = self.get_noise(self._latent_sde)
sde_sample = self.distribution.mean + noise
actions = rex_sample*f + sde_sample*(1-f)
self.gaussian_actions = actions
return self.prob_squashing_type.apply(actions)
def mode(self) -> th.Tensor: