diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index a76886e..d42dc7f 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -307,7 +307,10 @@ class ActorCriticPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) elif isinstance(self.action_dist, UniversalGaussianDistribution): if self.sqrt_induced_gaussian: - cov_sqrt = self.chol_net(latent_pi) + chol_sqrt_cov = self.chol_net(latent_pi) + if len(chol_sqrt_cov.shape) == 2: + chol_sqrt_cov = th.diag_embed(chol_sqrt_cov) + cov_sqrt = th.bmm(chol_sqrt_cov.mT, chol_sqrt_cov) dist = self.action_dist.proba_distribution_from_sqrt( mean_actions, cov_sqrt, latent_pi) mean, chol = get_mean_and_chol(dist, expand=False)