Fix backward pass error for hybrid-method
This commit is contained in:
parent
75feefbe5a
commit
09159774d9
@ -325,7 +325,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
def _sample_hybrid(self) -> th.Tensor:
|
||||
f = self.hybrid_rex_fac
|
||||
actions = self._sample_normal()*f + self._sample_sde()*(1-f)
|
||||
|
||||
rex_sample = self.distribution.rsample()
|
||||
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
sde_sample = self.distribution.mean + noise
|
||||
|
||||
actions = rex_sample*f + sde_sample*(1-f)
|
||||
self.gaussian_actions = actions
|
||||
return self.prob_squashing_type.apply(actions)
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
|
Loading…
Reference in New Issue
Block a user