diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index dd8d1aa..db3d2a1 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -365,7 +365,12 @@ class UniversalGaussianDistribution(SB3_Distribution): if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices): chol = th.diag_embed(self.distribution.stddev) return (th.mm(latent_sde, self.exploration_mat) @ chol)[0] - chol = self.distribution.scale_tril + p = self.distribution + if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): + chol = p.stddev + elif isinstance(p, th.distributions.MultivariateNormal): + chol = p.scale_tril + # Use batch matrix multiplication for efficient computation # (batch_size, n_features) -> (batch_size, 1, n_features) latent_sde = latent_sde.unsqueeze(dim=1)