From 75feefbe5a43d626bc75dce9001dddf9a91a0027 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 18 Feb 2023 12:42:07 +0100 Subject: [PATCH] Fix: Typos for sde/rex hybrid --- metastable_baselines/distributions/distributions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 22dc522..b400c64 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -159,9 +159,8 @@ class UniversalGaussianDistribution(SB3_Distribution): self.learn_features = sde_learn_features self.sde_latent_softmax = sde_latent_softmax - if self.use_hybrid and not self.use_sde: - print('[!] use_hybrid forces use_sde to be true') - self.use_sde = True + if self.use_hybrid: + assert self.use_sde, 'use_sde has to be set to use use_hybrid' assert (self.par_type != ParametrizationType.NONE) == ( self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' @@ -318,13 +317,14 @@ class UniversalGaussianDistribution(SB3_Distribution): return self.prob_squashing_type.apply(sample) def _sample_sde(self) -> th.Tensor: + # More Reparametrization trick to pass gradients noise = self.get_noise(self._latent_sde) actions = self.distribution.mean + noise self.gaussian_actions = actions return self.prob_squashing_type.apply(actions) def _sample_hybrid(self) -> th.Tensor: - f = self.hybrid_rex_factor + f = self.hybrid_rex_fac actions = self._sample_normal()*f + self._sample_sde()*(1-f) return self.prob_squashing_type.apply(actions)