PCA did not allow cont cov

This commit is contained in:
Dominik Moritz Roth 2023-08-22 00:05:04 +02:00
parent f3683afb86
commit 82c6674615

View File

@ -581,7 +581,7 @@ class ActorCriticPolicy(BasePolicy):
latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init
)
elif isinstance(self.action_dist, PCA_Distribution):
self.action_net, self.log_std = self.action_dist.proba_distribution_net(
self.action_net, self.std_net = self.action_dist.proba_distribution_net(
latent_dim=latent_dim_pi
)
elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)):
@ -677,7 +677,8 @@ class ActorCriticPolicy(BasePolicy):
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
elif isinstance(self.action_dist, PCA_Distribution):
return self.action_dist.proba_distribution(mean_actions, th.ones_like(mean_actions) * self.log_std.exp())
std_actions = self.std_net(latent_pi)
return self.action_dist.proba_distribution(mean_actions, std_actions)
else:
raise ValueError("Invalid action distribution")