Fix: Typos for sde/rex hybrid
This commit is contained in:
parent
c62723bef6
commit
75feefbe5a
@ -159,9 +159,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
self.learn_features = sde_learn_features
|
||||
self.sde_latent_softmax = sde_latent_softmax
|
||||
|
||||
if self.use_hybrid and not self.use_sde:
|
||||
print('[!] use_hybrid forces use_sde to be true')
|
||||
self.use_sde = True
|
||||
if self.use_hybrid:
|
||||
assert self.use_sde, 'use_sde has to be set to use use_hybrid'
|
||||
|
||||
assert (self.par_type != ParametrizationType.NONE) == (
|
||||
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||
@ -318,13 +317,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
return self.prob_squashing_type.apply(sample)
|
||||
|
||||
def _sample_sde(self) -> th.Tensor:
|
||||
# More Reparametrization trick to pass gradients
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
actions = self.distribution.mean + noise
|
||||
self.gaussian_actions = actions
|
||||
return self.prob_squashing_type.apply(actions)
|
||||
|
||||
def _sample_hybrid(self) -> th.Tensor:
|
||||
f = self.hybrid_rex_factor
|
||||
f = self.hybrid_rex_fac
|
||||
actions = self._sample_normal()*f + self._sample_sde()*(1-f)
|
||||
return self.prob_squashing_type.apply(actions)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user