Support passing trajectories into prediction functions (for PCA)

Dominik Moritz Roth 2023-05-21 20:16:25 +02:00
parent d4003e1a68
commit 9cd25ee484
3 changed files with 84 additions and 4 deletions


@@ -280,4 +280,4 @@ class GaussianRolloutCollectorAuxclass():
     def get_past_trajectories(self) -> th.Tensor:
         # TODO: Respect Episode Boundaries
-        return th.Tensor(self.rollout_buffer.actions)
+        return np.swapaxes(th.Tensor(self.rollout_buffer.actions[:self.rollout_buffer.pos]), 0, 1)
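
A minimal sketch of the shape handling, assuming the rollout buffer stores actions as (n_steps, n_envs, action_dim) as in stable-baselines3 and that `pos` marks the first unwritten step; the sizes below are illustrative:

    import numpy as np
    import torch as th

    n_steps, n_envs, action_dim = 2048, 4, 6   # illustrative buffer dimensions
    actions = np.zeros((n_steps, n_envs, action_dim), dtype=np.float32)
    pos = 128                                  # only the first `pos` steps are filled so far

    # Keep only the filled part of the buffer, then swap the first two axes so
    # each row is one environment's action trajectory: (n_envs, pos, action_dim).
    trajectories = np.swapaxes(th.Tensor(actions[:pos]), 0, 1)
    print(trajectories.shape)  # torch.Size([4, 128, 6])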


@@ -359,7 +359,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -367,7 +367,10 @@ class ActorCriticPolicy(BasePolicy):
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
     def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
@@ -409,6 +412,57 @@ class ActorCriticPolicy(BasePolicy):
         latent_vf = self.mlp_extractor.forward_critic(features)
         return self.value_net(latent_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this correspond to beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        # TODO (GH/1): add support for RNN policies
+        # if state is None:
+        #     state = self.initial_state
+        # if episode_start is None:
+        #     episode_start = [False for _ in range(self.n_envs)]
+        # Switch to eval mode (this affects batch norm / dropout)
+        self.set_training_mode(False)
+
+        observation, vectorized_env = self.obs_to_tensor(observation)
+
+        with th.no_grad():
+            actions = self._predict(observation, deterministic=deterministic, trajectory=trajectory)
+        # Convert to numpy, and reshape to the original action shape
+        actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
+
+        if isinstance(self.action_space, gym.spaces.Box):
+            if self.squash_output:
+                # Rescale to proper domain when using squashing
+                actions = self.unscale_action(actions)
+            else:
+                # Actions could be on arbitrary scale, so clip the actions to avoid
+                # out of bound error (e.g. if sampling from a Gaussian distribution)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        # Remove batch dimension if needed
+        if not vectorized_env:
+            actions = actions.squeeze(axis=0)
+
+        return actions, state
+
 
 class ActorCriticCnnPolicy(ActorCriticPolicy):
     """


@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from typing import Any, Dict, Optional, Type, Union, NamedTuple, Tuple
 
 import numpy as np
 import torch as th
@@ -422,6 +422,32 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this correspond to beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        if self.policy.use_pca:
+            return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+        else:
+            return self.policy.predict(observation, state, episode_start, deterministic)
+
     def learn(
         self,
         total_timesteps: int,
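
An end-to-end usage sketch of the new `trajectory` argument, assuming a PPO instance from this fork; the package import path, the way `use_pca` is enabled, and the environment id are all assumptions, not taken from this commit:

    import gym

    # Assumption: adjust the import to this fork's actual package name.
    from metastable_baselines.ppo import PPO

    env = gym.make("Pendulum-v1")
    # Assumption: use_pca is switched on through policy_kwargs in this fork.
    model = PPO("MlpPolicy", env, policy_kwargs=dict(use_pca=True), verbose=0)
    model.learn(total_timesteps=10_000)

    # Past actions collected by the rollout buffer, shaped (n_envs, n_steps, action_dim)
    past = model.get_past_trajectories()

    obs = env.reset()
    # The override forwards `trajectory` to the policy only when use_pca is enabled
    action, _state = model.predict(obs, deterministic=True, trajectory=past)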