Support passing trajectories into prediction functions (for pca)

Dominik Moritz Roth 2023-05-21 20:16:25 +02:00
parent d4003e1a68
commit 9cd25ee484
3 changed files with 84 additions and 4 deletions
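
As a quick orientation, a minimal usage sketch of what this commit enables is given below. It assumes `env` is a Gym environment and `model` is a PPO instance from this fork whose policy was built with `use_pca=True`; only `predict`, `get_past_trajectories`, and the `trajectory` keyword are taken from the hunks below, the rest is illustrative.

    # Hedged usage sketch; `env` and `model` are assumptions, not part of this commit.
    obs = env.reset()

    # Past actions from the rollout buffer, reshaped by get_past_trajectories()
    # (first hunk below) into one action trajectory per environment.
    trajectory = model.get_past_trajectories()

    # The new keyword travels PPO.predict -> policy.predict -> policy._predict
    # -> distribution.get_actions, and only when policy.use_pca is set; otherwise
    # predict() behaves exactly as before.
    action, _state = model.predict(obs, deterministic=True, trajectory=trajectory)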

View File

@@ -280,4 +280,4 @@ class GaussianRolloutCollectorAuxclass():
     def get_past_trajectories(self) -> th.Tensor:
         # TODO: Respect Episode Boundaries
-        return th.Tensor(self.rollout_buffer.actions)
+        return np.swapaxes(th.Tensor(self.rollout_buffer.actions[:self.rollout_buffer.pos]), 0, 1)
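
The reshape above is easiest to read with concrete shapes. Assuming the usual stable-baselines3 buffer layout of (buffer_size, n_envs, action_dim), slicing to `rollout_buffer.pos` keeps only the part of the buffer filled so far, and the axis swap yields one trajectory per environment; a small sketch under that assumption:

    import numpy as np
    import torch as th

    # Assumed layout: (buffer_size, n_envs, action_dim), filled up to `pos`.
    buffer_actions = np.zeros((2048, 8, 6), dtype=np.float32)
    pos = 1500

    # np.swapaxes dispatches to Tensor.swapaxes (torch >= 1.8), so the result stays a tensor.
    trajectories = np.swapaxes(th.Tensor(buffer_actions[:pos]), 0, 1)
    print(trajectories.shape)  # torch.Size([8, 1500, 6]) = (n_envs, steps_so_far, action_dim)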

View File

@@ -359,7 +359,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -367,7 +367,10 @@ class ActorCriticPolicy(BasePolicy):
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
     def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
@@ -409,6 +412,57 @@ class ActorCriticPolicy(BasePolicy):
         latent_vf = self.mlp_extractor.forward_critic(features)
         return self.value_net(latent_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
+
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.set_training_mode(False)
+
+        observation, vectorized_env = self.obs_to_tensor(observation)
+
+        with th.no_grad():
+            actions = self._predict(observation, deterministic=deterministic, trajectory=trajectory)
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
+
+        if isinstance(self.action_space, gym.spaces.Box):
+            if self.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        return actions, state
+
 
 class ActorCriticCnnPolicy(ActorCriticPolicy):
     """

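The `get_actions(..., trajectory=...)` call in `_predict` implies that the fork's PCA distribution classes accept the extra keyword; those classes are not part of this diff. The stand-in below is only an illustration of the expected interface, with every name other than `get_actions`, `deterministic`, and `trajectory` being an assumption:

    import torch as th

    # Hypothetical stand-in, not from this repository.
    class PCAAwareDistribution:
        def __init__(self, mean: th.Tensor, std: th.Tensor):
            self.mean = mean
            self.std = std

        def get_actions(self, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
            if deterministic:
                return self.mean
            # The fork's real PCA distributions decide how `trajectory` (one past-action
            # sequence per environment) conditions the sample; this sketch just draws
            # Gaussian noise and ignores it, to show where the keyword ends up.
            noise = th.randn_like(self.mean) * self.std
            return self.mean + noise
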
View File

@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Tuple
 
 import numpy as np
 import torch as th
@@ -422,6 +422,32 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        if self.policy.use_pca:
+            return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+        else:
+            return self.policy.predict(observation, state, episode_start, deterministic)
+
     def learn(
         self,
         total_timesteps: int,