def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
    """Produce an action for ``obs`` from the policy's action distribution.

    The distribution parameters are obtained from
    ``self.get_action_dist_params(obs)`` and handed to
    ``self.action_dist.actions_from_params``.

    :param obs: Observation passed to ``get_action_dist_params``.
    :param deterministic: Forwarded unchanged as the ``deterministic`` keyword.
    :param trajectory: Forwarded as the ``trajectory`` keyword only when
        ``self.action_dist`` is a ``PCA_Distribution``; ignored otherwise.
    :return: The value returned by ``actions_from_params``.
    """
    mean_actions, log_std, dist_kwargs = self.get_action_dist_params(obs)

    # Only PCA_Distribution receives the `trajectory` keyword; every other
    # distribution is called without it.
    if isinstance(self.action_dist, PCA_Distribution):
        trajectory_kwargs = {"trajectory": trajectory}
    else:
        trajectory_kwargs = {}

    # Note: the action is squashed
    return self.action_dist.actions_from_params(
        mean_actions,
        log_std,
        deterministic=deterministic,
        **trajectory_kwargs,
        **dist_kwargs,
    )