From c6a58b15ddd36dc3d5d55bf61829f514b11f4750 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 22 Aug 2022 14:19:40 +0200 Subject: [PATCH] Fixing SDE bug --- metastable_baselines/distributions/distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index db3d2a1..22831d4 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -367,7 +367,7 @@ class UniversalGaussianDistribution(SB3_Distribution): return (th.mm(latent_sde, self.exploration_mat) @ chol)[0] p = self.distribution if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): - chol = p.stddev + chol = th.diag_embed(self.distribution.stddev) elif isinstance(p, th.distributions.MultivariateNormal): chol = p.scale_tril