Implement Buffer that stores Distribution params

This commit is contained in:
Dominik Moritz Roth 2024-01-16 15:12:25 +01:00
parent 5ed5d32083
commit 1fa66611a3
3 changed files with 452 additions and 23 deletions

View File

@ -0,0 +1,386 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
DictRolloutBufferSamples,
RolloutBufferSamples,
TensorDict
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize
try:
# Check memory used by replay buffer when possible
import psutil
except ImportError:
psutil = None
from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer
class BetterRolloutBufferSamples(NamedTuple):
observations: th.Tensor
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
mean: th.Tensor
cov_decomp: th.Tensor
advantages: th.Tensor
returns: th.Tensor
class BetterDictRolloutBufferSamples(NamedTuple):
observations: TensorDict
actions: th.Tensor
old_values: th.Tensor
old_log_prob: th.Tensor
mean: th.Tensor
cov_decomp: th.Tensor
advantages: th.Tensor
returns: th.Tensor
class BetterRolloutBuffer(RolloutBuffer):
"""
Extended to also save the mean and cov decomp.
Rollout buffer used in on-policy algorithms like A2C/PPO.
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
values: np.ndarray
means: np.ndarray
cov_decomps: np.ndarray
def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.reset()
def reset(self) -> None:
self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
# Note: assumes a per-dimension (e.g. diagonal) covariance decomposition; adjust the shape for full covariances
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super().reset()
def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
"""
Post-processing step: compute the lambda-return (TD(lambda) estimate)
and GAE(lambda) advantage.
Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
where R is the sum of discounted reward with value bootstrap
(because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.
The TD(lambda) estimator also has two special cases:
- TD(1) is Monte-Carlo estimate (sum of discounted rewards)
- TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))
For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.
:param last_values: state value estimation for the last step (one for each env)
:param dones: if the last step was a terminal step (one bool for each env).
"""
# Convert to numpy
last_values = last_values.clone().cpu().numpy().flatten()
last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
if step == self.buffer_size - 1:
next_non_terminal = 1.0 - dones
next_values = last_values
else:
next_non_terminal = 1.0 - self.episode_starts[step + 1]
next_values = self.values[step + 1]
delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
self.advantages[step] = last_gae_lam
# TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
# in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
self.returns = self.advantages + self.values
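# For reference, the recursion implemented above (standard GAE(lambda), restated; here
# d_{t+1} denotes the episode-start/terminal flag of the next step):
#   delta_t   = r_t + gamma * (1 - d_{t+1}) * V(s_{t+1}) - V(s_t)
#   A_hat_t   = delta_t + gamma * lambda * (1 - d_{t+1}) * A_hat_{t+1}
#   returns_t = A_hat_t + V(s_t)    # TD(lambda) return used as the value target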
def add(
self,
obs: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
mean: th.Tensor,
cov_decomp: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space, spaces.Discrete):
obs = obs.reshape((self.n_envs, *self.obs_shape))
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.observations[self.pos] = np.array(obs)
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.means[self.pos] = mean.clone().cpu().numpy()
self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[BetterRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
_tensor_names = [
"observations",
"actions",
"values",
"log_probs",
"means",
"cov_decomps",
"advantages",
"returns",
]
for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
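# Note: `swap_and_flatten` (inherited from the SB3 base buffer) reshapes each array from
# (buffer_size, n_envs, *shape) to (buffer_size * n_envs, *shape), which is why the random
# permutation above is drawn over `buffer_size * n_envs` flat indices. A rough numpy
# equivalent, assuming arr.shape == (buffer_size, n_envs, *rest):
#   arr.swapaxes(0, 1).reshape(buffer_size * n_envs, *rest)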
def _get_samples(
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> BetterRolloutBufferSamples:
data = (
self.observations[batch_inds],
self.actions[batch_inds],
self.values[batch_inds].flatten(),
self.log_probs[batch_inds].flatten(),
self.means[batch_inds],
self.cov_decomps[batch_inds],
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
)
return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data)))
class BetterDictRolloutBuffer(DictRolloutBuffer):
"""
Extended to also save the mean and cov decomp.
Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
Extends the RolloutBuffer to use dictionary observations
It corresponds to ``buffer_size`` transitions collected
using the current policy.
This experience will be discarded after the policy update.
In order to use PPO objective, we also store the current value of each state
and the log probability of each taken action.
The term rollout here refers to the model-free notion and should not
be used with the concept of rollout used in model-based RL or planning.
Hence, it is only involved in policy and value function training but not action selection.
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to Monte-Carlo advantage estimate when set to 1.
:param gamma: Discount factor
:param n_envs: Number of parallel environments
"""
observation_space: spaces.Dict
obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment]
observations: Dict[str, np.ndarray] # type: ignore[assignment]
def __init__(
self,
buffer_size: int,
observation_space: spaces.Dict,
action_space: spaces.Space,
device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
):
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
self.gae_lambda = gae_lambda
self.gamma = gamma
self.generator_ready = False
self.reset()
def reset(self) -> None:
self.observations = {}
for key, obs_input_shape in self.obs_shape.items():
self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
# Note: assumes a per-dimension (e.g. diagonal) covariance decomposition; adjust the shape for full covariances
self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super(RolloutBuffer, self).reset()
def add( # type: ignore[override]
self,
obs: Dict[str, np.ndarray],
action: np.ndarray,
reward: np.ndarray,
episode_start: np.ndarray,
value: th.Tensor,
log_prob: th.Tensor,
mean: th.Tensor,
cov_decomp: th.Tensor,
) -> None:
"""
:param obs: Observation
:param action: Action
:param reward:
:param episode_start: Start of episode signal.
:param value: estimated value of the current state
following the current policy.
:param log_prob: log probability of the action
following the current policy.
"""
if len(log_prob.shape) == 0:
# Reshape 0-d tensor to avoid error
log_prob = log_prob.reshape(-1, 1)
for key in self.observations.keys():
obs_ = np.array(obs[key])
# Reshape needed when using multiple envs with discrete observations
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
if isinstance(self.observation_space.spaces[key], spaces.Discrete):
obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
self.observations[key][self.pos] = obs_
# Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action)
self.rewards[self.pos] = np.array(reward)
self.episode_starts[self.pos] = np.array(episode_start)
self.values[self.pos] = value.clone().cpu().numpy().flatten()
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
self.means[self.pos] = mean.clone().cpu().numpy()
self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
self.pos += 1
if self.pos == self.buffer_size:
self.full = True
def get( # type: ignore[override]
self,
batch_size: Optional[int] = None,
) -> Generator[BetterDictRolloutBufferSamples, None, None]:
assert self.full, "Rollout buffer must be full before sampling from it"
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
for key, obs in self.observations.items():
self.observations[key] = self.swap_and_flatten(obs)
_tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
for tensor in _tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
# Return everything, don't create minibatches
if batch_size is None:
batch_size = self.buffer_size * self.n_envs
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
def _get_samples( # type: ignore[override]
self,
batch_inds: np.ndarray,
env: Optional[VecNormalize] = None,
) -> BetterDictRolloutBufferSamples:
return BetterDictRolloutBufferSamples(
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
actions=self.to_torch(self.actions[batch_inds]),
old_values=self.to_torch(self.values[batch_inds].flatten()),
old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
mean=self.to_torch(self.means[batch_inds]),
cov_decomp=self.to_torch(self.cov_decomps[batch_inds]),
advantages=self.to_torch(self.advantages[batch_inds].flatten()),
returns=self.to_torch(self.returns[batch_inds].flatten()),
)
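A minimal usage sketch for the new buffer (illustrative only, not part of the commit). It assumes a diagonal-Gaussian policy, so that the per-step `mean` and `cov_decomp` (scale) passed to `add()` have shape `(n_envs, action_dim)`, and it assumes `BetterRolloutBuffer` is importable from the file above:

import numpy as np
import torch as th
from gymnasium import spaces

n_steps, n_envs, obs_dim, act_dim = 8, 2, 4, 3
buffer = BetterRolloutBuffer(
    n_steps,
    spaces.Box(low=-1.0, high=1.0, shape=(obs_dim,)),
    spaces.Box(low=-1.0, high=1.0, shape=(act_dim,)),
    device="cpu",
    gae_lambda=0.95,
    gamma=0.99,
    n_envs=n_envs,
)

for _ in range(n_steps):
    buffer.add(
        obs=np.random.randn(n_envs, obs_dim).astype(np.float32),
        action=np.random.randn(n_envs, act_dim).astype(np.float32),
        reward=np.zeros(n_envs, dtype=np.float32),
        episode_start=np.zeros(n_envs, dtype=np.float32),
        value=th.zeros(n_envs),
        log_prob=th.zeros(n_envs),
        mean=th.zeros(n_envs, act_dim),       # old distribution mean per env
        cov_decomp=th.ones(n_envs, act_dim),  # e.g. the diagonal scale per env
    )

buffer.compute_returns_and_advantage(last_values=th.zeros(n_envs), dones=np.zeros(n_envs))
for batch in buffer.get(batch_size=4):
    # batch.mean / batch.cov_decomp carry the old distribution parameters for each sample
    print(batch.mean.shape, batch.cov_decomp.shape)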

View File

@ -8,6 +8,7 @@ from gymnasium import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@ -70,6 +71,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
use_sde: bool,
sde_sample_freq: int,
use_pca: bool,
pca_is: bool,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
@ -85,6 +87,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
assert not (use_sde and use_pca), "use_sde and use_pca are mutually exclusive"
self.use_pca = use_pca
assert not pca_is or use_pca, "pca_is requires use_pca"
self.pca_is = pca_is
assert not rollout_buffer_class and not rollout_buffer_kwargs, "custom rollout buffers are not supported here"
super().__init__(
@ -122,9 +127,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
- self.rollout_buffer_class = DictRolloutBuffer
+ self.rollout_buffer_class = BetterDictRolloutBuffer
else:
- self.rollout_buffer_class = RolloutBuffer
+ self.rollout_buffer_class = BetterRolloutBuffer
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
@ -186,7 +191,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
- actions, values, log_probs = self.policy(obs_tensor)
+ actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is)
actions = actions.cpu().numpy()
# Rescale and perform action
@ -231,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
terminal_value = self.policy.predict_values(terminal_obs)[0]
rewards[idx] += self.gamma * terminal_value
- rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
+ rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
self._last_obs = new_obs
self._last_episode_starts = dones
@ -254,6 +259,30 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
"""
raise NotImplementedError
def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
trajectory: Optional[th.Tensor] = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:param trajectory: Past trajectory. Only required when using PCA.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
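# Usage note (a sketch, not part of the commit): with a PCA-based policy the caller is expected
# to pass the running trajectory explicitly, e.g.
#   action, state = model.predict(obs, deterministic=True, trajectory=past_actions)
# For non-PCA policies `trajectory` can simply stay None. The exact shape and semantics of
# `trajectory` are defined by PCA_Distribution and are not shown in this diff.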
def learn(
self: SelfOnPolicyAlgorithm,
total_timesteps: int,

View File

@ -34,9 +34,12 @@ from stable_baselines3.common.torch_layers import (
from stable_baselines3.common.policies import ContinuousCritic
- from stable_baselines3.common.type_aliases import Schedule
+ from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
from .distributions import make_proba_distribution
from metastable_baselines2.common.pca import PCA_Distribution
@ -395,7 +398,7 @@ class BasePolicy(BaseModel, ABC):
class ActorCriticPolicy(BasePolicy):
"""
Policy class for actor-critic algorithms (has both policy and value prediction).
- Used by A2C, PPO and the likes.
+ Used by A2C, PPO, TRPL and the likes.
:param observation_space: Observation space
:param action_space: Action space
@ -445,6 +448,7 @@ class ActorCriticPolicy(BasePolicy):
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
dist_kwargs: Optional[Dict[str, Any]] = {},
policy_projection: BaseProjectionLayer = IdentityProjectionLayer(),
):
if optimizer_kwargs is None:
optimizer_kwargs = {}
@ -514,6 +518,8 @@ class ActorCriticPolicy(BasePolicy):
self.use_pca = use_pca
self.dist_kwargs = dist_kwargs
self.policy_projection = policy_projection
# Action distribution
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
@ -624,7 +630,7 @@ class ActorCriticPolicy(BasePolicy):
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
- def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Distribution]:
"""
Forward pass in all the networks (actor and critic)
@ -644,9 +650,13 @@ class ActorCriticPolicy(BasePolicy):
values = self.value_net(latent_vf)
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
if conditioned_log_probs:
assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.'
log_prob = distribution.conditioned_log_prob(actions)
else:
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape))
- return actions, values, log_prob
+ return actions, values, log_prob, distribution
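# Usage note (sketch): callers of forward() now unpack four values, e.g.
#   actions, values, log_probs, distribution = policy(obs_tensor, conditioned_log_probs=pca_is)
# and the returned distribution exposes the parameters that the rollout buffer stores
# (assumed here to be `distribution.mean` and `distribution.scale`, mirroring the
# rollout collection change in this commit).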
def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
@ -691,7 +701,7 @@ class ActorCriticPolicy(BasePolicy):
else:
raise ValueError("Invalid action distribution")
- def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
@ -699,19 +709,23 @@ class ActorCriticPolicy(BasePolicy):
:param deterministic: Whether to use stochastic or deterministic actions
:return: Taken action according to the policy
"""
if self.use_pca:
return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
else:
return self.get_distribution(observation).get_actions(deterministic=deterministic)
- def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+ def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor], th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
- :param obs: Observation
+ :param rollout_data: The rollout samples (containing the observations and the old policy's distribution parameters)
:param actions: Actions
:return: estimated value, log likelihood of taking those actions,
entropy of the action distribution and the trust region (projection) loss.
"""
# Preprocess the observation if needed
obs = rollout_data.observations
features = self.extract_features(obs)
if self.share_features_extractor:
latent_pi, latent_vf = self.mlp_extractor(features)
@ -719,11 +733,13 @@ class ActorCriticPolicy(BasePolicy):
pi_features, vf_features = features
latent_pi = self.mlp_extractor.forward_actor(pi_features)
latent_vf = self.mlp_extractor.forward_critic(vf_features)
- distribution = self._get_action_dist_from_latent(latent_pi)
+ raw_distribution = self._get_action_dist_from_latent(latent_pi)
+ distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data)
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
entropy = distribution.entropy()
- return values, log_prob, entropy
+ trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
+ return values, log_prob, entropy, trust_region_loss
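# Usage note (a sketch of how a TRPL/PPO-style update could consume the extra return value;
# the weighting below is illustrative and not part of this commit):
#   values, log_prob, entropy, trust_region_loss = policy.evaluate_actions(rollout_data, rollout_data.actions)
#   loss = policy_loss + vf_coef * value_loss - ent_coef * entropy.mean() + trust_region_loss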
def get_distribution(self, obs: th.Tensor) -> Distribution:
"""
@ -918,18 +934,18 @@ class Actor(BasePolicy):
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
return mean_actions, log_std, {}
- def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+ def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
- return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
+ return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# return action and associated log prob
return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
- def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
- return self(observation, deterministic)
+ def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: Optional[th.Tensor] = None) -> th.Tensor:
+ return self(observation, deterministic, trajectory=trajectory)
class SACPolicy(BasePolicy):
@ -1018,16 +1034,14 @@ class SACPolicy(BasePolicy):
}
self.actor_kwargs = self.net_args.copy()
- sde_kwargs = {
+ self.actor_kwargs.update({
"use_sde": use_sde,
"use_pca": use_pca,
"log_std_init": log_std_init,
"use_expln": use_expln,
"clip_mean": clip_mean,
"dist_kwargs": dist_kwargs,
- }
- self.actor_kwargs.update(sde_kwargs)
+ })
self.critic_kwargs = self.net_args.copy()
self.critic_kwargs.update(
{