diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index b96e628..aae1d74 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -280,4 +280,4 @@ class GaussianRolloutCollectorAuxclass():
 
     def get_past_trajectories(self) -> th.Tensor:
         # TODO: Respect Episode Boundaries
-        return th.Tensor(self.rollout_buffer.actions)
+        return np.swapaxes(th.Tensor(self.rollout_buffer.actions[:self.rollout_buffer.pos]), 0, 1)
diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index 94d13d0..bba4f36 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -359,7 +359,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -367,7 +367,10 @@
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
     def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
@@ -409,6 +412,57 @@ class ActorCriticPolicy(BasePolicy):
         latent_vf = self.mlp_extractor.forward_critic(features)
         return self.value_net(latent_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this correspond to beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.set_training_mode(False)
+
+        observation, vectorized_env = self.obs_to_tensor(observation)
+
+        with th.no_grad():
+            actions = self._predict(observation, deterministic=deterministic, trajectory=trajectory)
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
+
+        if isinstance(self.action_space, gym.spaces.Box):
+            if self.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        return actions, state
+
 
 class ActorCriticCnnPolicy(ActorCriticPolicy):
     """
diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index bfc2844..abd34cf 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Tuple
 
 import numpy as np
 import torch as th
@@ -422,6 +422,32 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this correspond to beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        if self.policy.use_pca:
+            return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+        else:
+            return self.policy.predict(observation, state, episode_start, deterministic)
+
     def learn(
         self,
         total_timesteps: int,