diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index a8761c8..d656542 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -351,7 +351,7 @@ class UniversalGaussianDistribution(SB3_Distribution): latent_sde = latent_sde if self.learn_features else latent_sde.detach() latent_sde = latent_sde[..., -self.latent_sde_dim:] if self.sde_latent_softmax: - latent_sde = th.softmax(dim=-1) + latent_sde = latent_sde.softmax(-1) latent_sde = th.nn.functional.normalize(latent_sde, dim=-1) # Default case: only one exploration matrix if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):