Fix: SAC Dists dont expect trajectory param
This commit is contained in:
parent
21e8a59770
commit
1cc3590462
@ -946,7 +946,10 @@ class Actor(BasePolicy):
|
|||||||
def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
|
def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
|
||||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
# Note: the action is squashed
|
# Note: the action is squashed
|
||||||
|
if isinstance(self.action_dist, PCA_Distribution):
|
||||||
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
|
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
|
||||||
|
|
||||||
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||||
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
|
||||||
|
Loading…
Reference in New Issue
Block a user