From 1cc35904625b050586200ca7c7d84511adb28176 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 16 Jul 2024 09:00:24 +0200 Subject: [PATCH] Fix: SAC Dists dont expect trajectory param --- metastable_baselines2/common/policies.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index 18e9428..f91b54c 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -946,7 +946,10 @@ class Actor(BasePolicy): def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed - return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs) + if isinstance(self.action_dist, PCA_Distribution): + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs) + else: + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs)