Fix: SAC Dists don't expect trajectory param

This commit is contained in:
Dominik Moritz Roth 2024-07-16 09:00:24 +02:00
parent 21e8a59770
commit 1cc3590462

View File

@ -946,7 +946,10 @@ class Actor(BasePolicy):
def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
    """Compute an action for ``obs`` from the actor's action distribution.

    The distribution parameters come from ``get_action_dist_params(obs)``
    and are forwarded to ``self.action_dist.actions_from_params``.

    :param obs: Observation passed to ``get_action_dist_params``.
    :param deterministic: Forwarded unchanged to ``actions_from_params``.
    :param trajectory: Forwarded to ``actions_from_params`` only when the
        action distribution is a ``PCA_Distribution``; ignored otherwise.
    :return: Whatever ``actions_from_params`` returns (annotated as a tensor).
    """
    mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
    dist = self.action_dist
    # Note: the action is squashed
    if not isinstance(dist, PCA_Distribution):
        # Other distributions are called without the ``trajectory`` keyword
        # (per the commit title, they do not expect it).
        return dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
    return dist.actions_from_params(
        mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs
    )
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs) mean_actions, log_std, kwargs = self.get_action_dist_params(obs)