Fixed bug with SDE

This commit is contained in:
Dominik Moritz Roth 2022-08-22 13:36:17 +02:00
parent a9e3f295b2
commit 197de7997c

View File

@ -365,7 +365,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices): if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
chol = th.diag_embed(self.distribution.stddev) chol = th.diag_embed(self.distribution.stddev)
return (th.mm(latent_sde, self.exploration_mat) @ chol)[0] return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
chol = self.distribution.scale_tril p = self.distribution
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
chol = p.stddev
elif isinstance(p, th.distributions.MultivariateNormal):
chol = p.scale_tril
# Use batch matrix multiplication for efficient computation # Use batch matrix multiplication for efficient computation
# (batch_size, n_features) -> (batch_size, 1, n_features) # (batch_size, n_features) -> (batch_size, 1, n_features)
latent_sde = latent_sde.unsqueeze(dim=1) latent_sde = latent_sde.unsqueeze(dim=1)