diff --git a/sbBrix/common/policies.py b/sbBrix/common/policies.py index df1c1cc..66de5c8 100644 --- a/sbBrix/common/policies.py +++ b/sbBrix/common/policies.py @@ -580,6 +580,10 @@ class ActorCriticPolicy(BasePolicy): self.action_net, self.log_std = self.action_dist.proba_distribution_net( latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init ) + elif isinstance(self.action_dist, PCA_Distribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi + ) elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) else: