Fix backward pass error for hybrid-method
This commit is contained in:
parent
75feefbe5a
commit
09159774d9
@ -325,7 +325,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
def _sample_hybrid(self) -> th.Tensor:
|
def _sample_hybrid(self) -> th.Tensor:
|
||||||
f = self.hybrid_rex_fac
|
f = self.hybrid_rex_fac
|
||||||
actions = self._sample_normal()*f + self._sample_sde()*(1-f)
|
|
||||||
|
rex_sample = self.distribution.rsample()
|
||||||
|
|
||||||
|
noise = self.get_noise(self._latent_sde)
|
||||||
|
sde_sample = self.distribution.mean + noise
|
||||||
|
|
||||||
|
actions = rex_sample*f + sde_sample*(1-f)
|
||||||
|
self.gaussian_actions = actions
|
||||||
return self.prob_squashing_type.apply(actions)
|
return self.prob_squashing_type.apply(actions)
|
||||||
|
|
||||||
def mode(self) -> th.Tensor:
|
def mode(self) -> th.Tensor:
|
||||||
|
Loading…
Reference in New Issue
Block a user