From 82c6674615610d1b9112f6d0f088e751589858a9 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 22 Aug 2023 00:05:04 +0200 Subject: [PATCH] PCA did not allow cont cov --- sbBrix/common/policies.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py index 8d717cc..40c226b 100644 --- a/sbBrix/common/policies.py +++ b/sbBrix/common/policies.py @@ -581,7 +581,7 @@ class ActorCriticPolicy(BasePolicy): latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init ) elif isinstance(self.action_dist, PCA_Distribution): - self.action_net, self.log_std = self.action_dist.proba_distribution_net( + self.action_net, self.std_net = self.action_dist.proba_distribution_net( latent_dim=latent_dim_pi ) elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): @@ -677,7 +677,8 @@ class ActorCriticPolicy(BasePolicy): elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) elif isinstance(self.action_dist, PCA_Distribution): - return self.action_dist.proba_distribution(mean_actions, th.ones_like(mean_actions) * self.log_std.exp()) + std_actions = self.std_net(latent_pi) + return self.action_dist.proba_distribution(mean_actions, std_actions) else: raise ValueError("Invalid action distribution")