Fixing SDE bug

This commit is contained in:
Dominik Moritz Roth 2022-08-22 14:19:40 +02:00
parent 197de7997c
commit c6a58b15dd

View File

@ -367,7 +367,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
return (th.mm(latent_sde, self.exploration_mat) @ chol)[0] return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
p = self.distribution p = self.distribution
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
chol = p.stddev chol = th.diag_embed(self.distribution.stddev)
elif isinstance(p, th.distributions.MultivariateNormal): elif isinstance(p, th.distributions.MultivariateNormal):
chol = p.scale_tril chol = p.scale_tril