Fix: Typos for sde/rex hybrid

This commit is contained in:
Dominik Moritz Roth 2023-02-18 12:42:07 +01:00
parent c62723bef6
commit 75feefbe5a

View File

@ -159,9 +159,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.learn_features = sde_learn_features
self.sde_latent_softmax = sde_latent_softmax
if self.use_hybrid and not self.use_sde:
print('[!] use_hybrid forces use_sde to be true')
self.use_sde = True
if self.use_hybrid:
assert self.use_sde, 'use_sde has to be set to use use_hybrid'
assert (self.par_type != ParametrizationType.NONE) == (
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
@ -318,13 +317,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
return self.prob_squashing_type.apply(sample)
def _sample_sde(self) -> th.Tensor:
# More Reparametrization trick to pass gradients
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
self.gaussian_actions = actions
return self.prob_squashing_type.apply(actions)
def _sample_hybrid(self) -> th.Tensor:
f = self.hybrid_rex_factor
f = self.hybrid_rex_fac
actions = self._sample_normal()*f + self._sample_sde()*(1-f)
return self.prob_squashing_type.apply(actions)