diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py new file mode 100644 index 0000000..3b9aeb6 --- /dev/null +++ b/metastable_baselines2/common/buffers.py @@ -0,0 +1,386 @@ +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple + +import numpy as np +import torch as th +from gymnasium import spaces + +from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape +from stable_baselines3.common.type_aliases import ( + DictRolloutBufferSamples, + RolloutBufferSamples, + TensorDict +) +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import VecNormalize + +try: + # Check memory used by replay buffer when possible + import psutil +except ImportError: + psutil = None + +from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer + +class BetterRolloutBufferSamples(NamedTuple): + observations: th.Tensor + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + mean: th.Tensor + cov_Decomp: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + +class BetterDictRolloutBufferSamples(NamedTuple): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + mean: th.Tensor + cov_Decomp: th.Tensor + advantages: th.Tensor + returns: th.Tensor + +class BetterRolloutBuffer(RolloutBuffer): + """ + Extended to also save the mean and cov decomp. + + Rollout buffer used in on-policy algorithms like A2C/PPO. + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + advantages: np.ndarray + returns: np.ndarray + episode_starts: np.ndarray + log_probs: np.ndarray + values: np.ndarray + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + self.gae_lambda = gae_lambda + self.gamma = gamma + self.generator_ready = False + self.reset() + + def reset(self) -> None: + self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.means = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.cov_decomps = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super().reset() + + def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None: + """ + Post-processing step: compute the lambda-return (TD(lambda) estimate) + and GAE(lambda) advantage. + + Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) + to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S)) + where R is the sum of discounted reward with value bootstrap + (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization. + + The TD(lambda) estimator has also two special cases: + - TD(1) is Monte-Carlo estimate (sum of discounted rewards) + - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1})) + + For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375. + + :param last_values: state value estimation for the last step (one for each env) + :param dones: if the last step was a terminal step (one bool for each env). + """ + # Convert to numpy + last_values = last_values.clone().cpu().numpy().flatten() + + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_values = last_values + else: + next_non_terminal = 1.0 - self.episode_starts[step + 1] + next_values = self.values[step + 1] + delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + self.advantages[step] = last_gae_lam + # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" + # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA + self.returns = self.advantages + self.values + + def add( + self, + obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + mean: th.Tensor, + cov_decomp: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space, spaces.Discrete): + obs = obs.reshape((self.n_envs, *self.obs_shape)) + + # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 + action = action.reshape((self.n_envs, self.action_dim)) + + self.observations[self.pos] = np.array(obs) + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.episode_starts[self.pos] = np.array(episode_start) + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.means[self.pos] = mean.clone().cpu().numpy() + self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + _tensor_names = [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> BetterRolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.means[batch_inds].flatten(), + self.cov_decomps[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) + return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data))) + + +class BetterDictRolloutBuffer(DictRolloutBuffer): + """ + Extended to also save the mean and cov decomp. + + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. + Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: PyTorch device + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to Monte-Carlo advantage estimate when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + observation_space: spaces.Dict + obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] + observations: Dict[str, np.ndarray] # type: ignore[assignment] + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Dict, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + + self.generator_ready = False + self.reset() + + def reset(self) -> None: + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( # type: ignore[override] + self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. + """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]) + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 + action = action.reshape((self.n_envs, self.action_dim)) + + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.episode_starts[self.pos] = np.array(episode_start) + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get( # type: ignore[override] + self, + batch_size: Optional[int] = None, + ) -> Generator[DictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples( # type: ignore[override] + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> BetterDictRolloutBufferSamples: + return BetterDictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + mean=self.to_torch(self.means[batch_inds].flatten()), + cov_decomp=self.to_torch(self.cov_decomps[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) \ No newline at end of file diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py index c2f9317..0320526 100644 --- a/metastable_baselines2/common/on_policy_algorithm.py +++ b/metastable_baselines2/common/on_policy_algorithm.py @@ -8,6 +8,7 @@ from gymnasium import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer +from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -70,6 +71,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): use_sde: bool, sde_sample_freq: int, use_pca: bool, + pca_is: bool, rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, stats_window_size: int = 100, @@ -85,6 +87,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): assert not (use_sde and use_pca) self.use_pca = use_pca + assert not pca_is or use_pca + self.pca_is = pca_is + assert not rollout_buffer_class and not rollout_buffer_kwargs super().__init__( @@ -122,9 +127,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): if self.rollout_buffer_class is None: if isinstance(self.observation_space, spaces.Dict): - self.rollout_buffer_class = DictRolloutBuffer + self.rollout_buffer_class = BetterDictRolloutBuffer else: - self.rollout_buffer_class = RolloutBuffer + self.rollout_buffer_class = BetterRolloutBuffer self.rollout_buffer = self.rollout_buffer_class( self.n_steps, @@ -186,7 +191,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): with th.no_grad(): # Convert to pytorch tensor or to TensorDict obs_tensor = obs_as_tensor(self._last_obs, self.device) - actions, values, log_probs = self.policy(obs_tensor) + actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is) actions = actions.cpu().numpy() # Rescale and perform action @@ -231,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): terminal_value = self.policy.predict_values(terminal_obs)[0] rewards[idx] += self.gamma * terminal_value - rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs) + rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale) self._last_obs = new_obs self._last_episode_starts = dones @@ -254,6 +259,30 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): """ raise NotImplementedError + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + trajectory: th.Tensor = None, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last hidden states (can be None, used in recurrent policies) + :param episode_start: The last masks (can be None, used in recurrent policies) + this correspond to beginning of episodes, + where the hidden states of the RNN must be reset. + :param deterministic: Whether or not to return deterministic actions. + :param trajectory: Past trajectory. Only required when using PCA. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory) + def learn( self: SelfOnPolicyAlgorithm, total_timesteps: int, diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py index e46ded6..9190d3f 100644 --- a/metastable_baselines2/common/policies.py +++ b/metastable_baselines2/common/policies.py @@ -34,9 +34,12 @@ from stable_baselines3.common.torch_layers import ( from stable_baselines3.common.policies import ContinuousCritic -from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor +from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLay, WassersteinProjectionLayer, KLProjectionLayer + + from .distributions import make_proba_distribution from metastable_baselines2.common.pca import PCA_Distribution @@ -395,7 +398,7 @@ class BasePolicy(BaseModel, ABC): class ActorCriticPolicy(BasePolicy): """ Policy class for actor-critic algorithms (has both policy and value prediction). - Used by A2C, PPO and the likes. + Used by A2C, PPO, TRPL and the likes. :param observation_space: Observation space :param action_space: Action space @@ -445,6 +448,7 @@ class ActorCriticPolicy(BasePolicy): optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, dist_kwargs: Optional[Dict[str, Any]] = {}, + policy_projection: BaseProjectionLayer = IdentityProjectionLayer(), ): if optimizer_kwargs is None: optimizer_kwargs = {} @@ -514,6 +518,8 @@ class ActorCriticPolicy(BasePolicy): self.use_pca = use_pca self.dist_kwargs = dist_kwargs + self.policy_projection = policy_projection + # Action distribution self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs) @@ -624,7 +630,7 @@ class ActorCriticPolicy(BasePolicy): # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Forward pass in all the networks (actor and critic) @@ -644,9 +650,13 @@ class ActorCriticPolicy(BasePolicy): values = self.value_net(latent_vf) distribution = self._get_action_dist_from_latent(latent_pi) actions = distribution.get_actions(deterministic=deterministic) - log_prob = distribution.log_prob(actions) + if conditioned_log_probs: + assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.' + log_prob = distribution.conditioned_log_prob(actions) + else: + log_prob = distribution.log_prob(actions) actions = actions.reshape((-1, *self.action_space.shape)) - return actions, values, log_prob + return actions, values, log_prob, distribution def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]: """ @@ -691,7 +701,7 @@ class ActorCriticPolicy(BasePolicy): else: raise ValueError("Invalid action distribution") - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -699,19 +709,23 @@ class ActorCriticPolicy(BasePolicy): :param deterministic: Whether to use stochastic or deterministic actions :return: Taken action according to the policy """ - return self.get_distribution(observation).get_actions(deterministic=deterministic) + if self.use_pca: + return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory) + else: + return self.get_distribution(observation).get_actions(deterministic=deterministic) - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: """ Evaluate actions according to the current policy, given the observations. - :param obs: Observation + :param rollout_data: The Rollouts (containing ) :param actions: Actions :return: estimated value, log likelihood of taking those actions and entropy of the action distribution. """ # Preprocess the observation if needed + obs = rollout_data.observations features = self.extract_features(obs) if self.share_features_extractor: latent_pi, latent_vf = self.mlp_extractor(features) @@ -719,11 +733,13 @@ class ActorCriticPolicy(BasePolicy): pi_features, vf_features = features latent_pi = self.mlp_extractor.forward_actor(pi_features) latent_vf = self.mlp_extractor.forward_critic(vf_features) - distribution = self._get_action_dist_from_latent(latent_pi) + raw_distribution = self._get_action_dist_from_latent(latent_pi) + distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data) log_prob = distribution.log_prob(actions) values = self.value_net(latent_vf) entropy = distribution.entropy() - return values, log_prob, entropy + trust_region_loss = self.projection.get_trust_region_loss(raw_distribution, old_distribution) + return values, log_prob, entropy, trust_region_loss def get_distribution(self, obs: th.Tensor) -> Distribution: """ @@ -918,18 +934,18 @@ class Actor(BasePolicy): log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) return mean_actions, log_std, {} - def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed - return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs) def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: - return self(observation, deterministic) + def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor: + return self(observation, deterministic, trajectory=trajectory) class SACPolicy(BasePolicy): @@ -1018,16 +1034,14 @@ class SACPolicy(BasePolicy): } self.actor_kwargs = self.net_args.copy() - sde_kwargs = { + self.actor_kwargs.update({ "use_sde": use_sde, "use_pca": use_pca, "log_std_init": log_std_init, "use_expln": use_expln, "clip_mean": clip_mean, "dist_kwargs": dist_kwargs, - } - - self.actor_kwargs.update(sde_kwargs) + }) self.critic_kwargs = self.net_args.copy() self.critic_kwargs.update( {