Fix: SAC Dists dont expect trajectory param
This commit is contained in:
parent
21e8a59770
commit
1cc3590462
@ -946,7 +946,10 @@ class Actor(BasePolicy):
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
# Note: the action is squashed
|
||||
if isinstance(self.action_dist, PCA_Distribution):
|
||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
|
||||
else:
|
||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||
|
||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||
|
Loading…
Reference in New Issue
Block a user