Support passing trajectories into prediction functions (for PCA)
parent d4003e1a68
commit 9cd25ee484
@@ -280,4 +280,4 @@ class GaussianRolloutCollectorAuxclass():
     def get_past_trajectories(self) -> th.Tensor:
         # TODO: Respect Episode Boundaries
-        return th.Tensor(self.rollout_buffer.actions)
+        return np.swapaxes(th.Tensor(self.rollout_buffer.actions[:self.rollout_buffer.pos]), 0, 1)
 
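For orientation, here is a minimal shape sketch of the new return value, assuming SB3's RolloutBuffer layout where stored actions have shape (buffer_size, n_envs, action_dim) and rollout_buffer.pos marks how far the buffer has been filled; the sizes below are illustrative only.

import numpy as np
import torch as th

buffer_size, n_envs, action_dim, pos = 2048, 4, 6, 128
actions = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

# Drop the unfilled tail, then put the env axis first so each row is one
# environment's past trajectory of shape (pos, action_dim).
trajectories = np.swapaxes(th.Tensor(actions[:pos]), 0, 1)
print(trajectories.shape)  # (n_envs, pos, action_dim), i.e. (4, 128, 6)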
@@ -359,7 +359,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -367,7 +367,10 @@ class ActorCriticPolicy(BasePolicy):
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
     def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
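To illustrate the dispatch above, here is a self-contained toy stand-in for a trajectory-aware distribution. The class, its averaging "conditioning", and all names are placeholders rather than this repo's PCA implementation; only the get_actions(deterministic=..., trajectory=...) keyword interface is taken from the diff.

import torch as th

class ToyTrajectoryDistribution:
    # Stand-in: a real PCA distribution would project the past trajectory
    # onto its principal components; here we just average it.
    def get_actions(self, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
        base = th.zeros(2)
        if trajectory is not None:
            base = base + trajectory.mean(dim=(0, 1))
        return base if deterministic else base + 0.01 * th.randn_like(base)

dist = ToyTrajectoryDistribution()
past = th.randn(4, 128, 2)   # (n_envs, pos, action_dim)
use_pca = True
action = (dist.get_actions(deterministic=True, trajectory=past)
          if use_pca
          else dist.get_actions(deterministic=True))
print(action.shape)          # torch.Size([2])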
@@ -409,6 +412,57 @@ class ActorCriticPolicy(BasePolicy):
         latent_vf = self.mlp_extractor.forward_critic(features)
         return self.value_net(latent_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.set_training_mode(False)
+
+        observation, vectorized_env = self.obs_to_tensor(observation)
+
+        with th.no_grad():
+            actions = self._predict(observation, deterministic=deterministic, trajectory=trajectory)
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
+
+        if isinstance(self.action_space, gym.spaces.Box):
+            if self.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        return actions, state
+
 
 class ActorCriticCnnPolicy(ActorCriticPolicy):
     """
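Calling the new policy-level predict directly would look roughly like this. It is a sketch: env and model are assumed to exist elsewhere, model being a PPO instance from this repo whose policy has use_pca enabled, and get_past_trajectories comes from the collector mixin changed above.

obs = env.reset()                                    # numpy observation(s)
trajectory = model.get_past_trajectories()           # (n_envs, pos, action_dim)
actions, _ = model.policy.predict(obs, deterministic=True, trajectory=trajectory)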
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Tuple
 
 import numpy as np
 import torch as th
@@ -422,6 +422,32 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        if self.policy.use_pca:
+            return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+        else:
+            return self.policy.predict(observation, state, episode_start, deterministic)
+
     def learn(
         self,
         total_timesteps: int,
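End to end, the algorithm-level override could be used in an evaluation loop along these lines. This is a sketch under the classic gym reset/step API; environment and model construction are assumed to happen elsewhere in this repo.

obs = env.reset()
for _ in range(1000):
    trajectory = model.get_past_trajectories()       # from the collector mixin
    # With policy.use_pca set, the trajectory is forwarded to the distribution;
    # otherwise this behaves like the stock predict().
    action, _state = model.predict(obs, deterministic=True, trajectory=trajectory)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()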