From 12e422aec7addf71f08d42c579c37aca36d7e081 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 7 Aug 2022 18:04:40 +0200
Subject: [PATCH] Why does KL double free?

---
 metastable_baselines/ppo/policies.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index a76886e..d42dc7f 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -307,7 +307,10 @@ class ActorCriticPolicy(BasePolicy):
             return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
         elif isinstance(self.action_dist, UniversalGaussianDistribution):
             if self.sqrt_induced_gaussian:
-                cov_sqrt = self.chol_net(latent_pi)
+                chol_sqrt_cov = self.chol_net(latent_pi)
+                if len(chol_sqrt_cov.shape) == 2:
+                    chol_sqrt_cov = th.diag_embed(chol_sqrt_cov)
+                cov_sqrt = th.bmm(chol_sqrt_cov.mT, chol_sqrt_cov)
                 dist = self.action_dist.proba_distribution_from_sqrt(
                     mean_actions, cov_sqrt, latent_pi)
                 mean, chol = get_mean_and_chol(dist, expand=False)