Implement Buffer that stores Distribution params

parent 5ed5d32083
commit 1fa66611a3

metastable_baselines2/common/buffers.py (new file, 386 lines)
@@ -0,0 +1,386 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
from gymnasium import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.type_aliases import (
    DictRolloutBufferSamples,
    RolloutBufferSamples,
    TensorDict
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

from stable_baselines3.common.buffers import RolloutBuffer, DictRolloutBuffer

class BetterRolloutBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor


class BetterDictRolloutBufferSamples(NamedTuple):
    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor
    old_log_prob: th.Tensor
    mean: th.Tensor
    cov_decomp: th.Tensor
    advantages: th.Tensor
    returns: th.Tensor

class BetterRolloutBuffer(RolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Rollout buffer used in on-policy algorithms like A2C/PPO.
    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    advantages: np.ndarray
    returns: np.ndarray
    episode_starts: np.ndarray
    log_probs: np.ndarray
    values: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Distribution parameters are stored per action dimension
        # (cov_decomps holds the diagonal scale / Cholesky factor of the covariance).
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super().reset()

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain the Monte-Carlo advantage estimate (A(s) = R - V(S)),
        where R is the sum of discounted rewards with value bootstrap
        (because we don't always have full episodes), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator also has two special cases:
        - TD(1) is the Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is the one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        last_gae_lam = 0
        for step in reversed(range(self.buffer_size)):
            if step == self.buffer_size - 1:
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[step + 1]
                next_values = self.values[step + 1]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values
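        # The loop above implements the GAE recursion:
        #   delta_t = r_t + gamma * V(s_{t+1}) * next_non_terminal - V(s_t)
        #   A_t     = delta_t + gamma * gae_lambda * next_non_terminal * A_{t+1}
        # and the TD(lambda) return used as the regression target is R_t = A_t + V(s_t).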

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution under the current policy.
        :param cov_decomp: decomposition (scale) of the action distribution's covariance.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs, *self.obs_shape))

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.observations[self.pos] = np.array(obs)
        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(self, batch_size: Optional[int] = None) -> Generator[BetterRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            _tensor_names = [
                "observations",
                "actions",
                "values",
                "log_probs",
                "means",
                "cov_decomps",
                "advantages",
                "returns",
            ]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterRolloutBufferSamples:
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.means[batch_inds],
            self.cov_decomps[batch_inds],
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return BetterRolloutBufferSamples(*tuple(map(self.to_torch, data)))


class BetterDictRolloutBuffer(DictRolloutBuffer):
    """
    Extended to also save the mean and cov decomp.

    Dict Rollout buffer used in on-policy algorithms like A2C/PPO.
    Extends the RolloutBuffer to use dictionary observations.

    It corresponds to ``buffer_size`` transitions collected
    using the current policy.
    This experience will be discarded after the policy update.
    In order to use PPO objective, we also store the current value of each state
    and the log probability of each taken action.

    The term rollout here refers to the model-free notion and should not
    be used with the concept of rollout used in model-based RL or planning.
    Hence, it is only involved in policy and value function training but not action selection.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to Monte-Carlo advantage estimate when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    observation_space: spaces.Dict
    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
    observations: Dict[str, np.ndarray]  # type: ignore[assignment]

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):
        super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"

        self.gae_lambda = gae_lambda
        self.gamma = gamma

        self.generator_ready = False
        self.reset()

    def reset(self) -> None:
        self.observations = {}
        for key, obs_input_shape in self.obs_shape.items():
            self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Distribution parameters (mean and covariance decomposition), as in BetterRolloutBuffer
        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
        self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.generator_ready = False
        super(RolloutBuffer, self).reset()

    def add(  # type: ignore[override]
        self,
        obs: Dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        mean: th.Tensor,
        cov_decomp: th.Tensor,
    ) -> None:
        """
        :param obs: Observation
        :param action: Action
        :param reward:
        :param episode_start: Start of episode signal.
        :param value: estimated value of the current state
            following the current policy.
        :param log_prob: log probability of the action
            following the current policy.
        :param mean: mean of the action distribution under the current policy.
        :param cov_decomp: decomposition (scale) of the action distribution's covariance.
        """
        if len(log_prob.shape) == 0:
            # Reshape 0-d tensor to avoid error
            log_prob = log_prob.reshape(-1, 1)

        for key in self.observations.keys():
            obs_ = np.array(obs[key])
            # Reshape needed when using multiple envs with discrete observations
            # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
            if isinstance(self.observation_space.spaces[key], spaces.Discrete):
                obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key])
            self.observations[key][self.pos] = obs_

        # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392
        action = action.reshape((self.n_envs, self.action_dim))

        self.actions[self.pos] = np.array(action)
        self.rewards[self.pos] = np.array(reward)
        self.episode_starts[self.pos] = np.array(episode_start)
        self.values[self.pos] = value.clone().cpu().numpy().flatten()
        self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
        self.means[self.pos] = mean.clone().cpu().numpy()
        self.cov_decomps[self.pos] = cov_decomp.clone().cpu().numpy()
        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True

    def get(  # type: ignore[override]
        self,
        batch_size: Optional[int] = None,
    ) -> Generator[BetterDictRolloutBufferSamples, None, None]:
        assert self.full, ""
        indices = np.random.permutation(self.buffer_size * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            for key, obs in self.observations.items():
                self.observations[key] = self.swap_and_flatten(obs)

            _tensor_names = ["actions", "values", "log_probs", "means", "cov_decomps", "advantages", "returns"]

            for tensor in _tensor_names:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            yield self._get_samples(indices[start_idx : start_idx + batch_size])
            start_idx += batch_size

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> BetterDictRolloutBufferSamples:
        return BetterDictRolloutBufferSamples(
            observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
            actions=self.to_torch(self.actions[batch_inds]),
            old_values=self.to_torch(self.values[batch_inds].flatten()),
            old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()),
            mean=self.to_torch(self.means[batch_inds]),
            cov_decomp=self.to_torch(self.cov_decomps[batch_inds]),
            advantages=self.to_torch(self.advantages[batch_inds].flatten()),
            returns=self.to_torch(self.returns[batch_inds].flatten()),
        )
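A minimal usage sketch of ``BetterRolloutBuffer`` follows; the sizes, the dummy Gaussian and the diagonal ``scale`` used as the covariance decomposition are assumptions for illustration only, not part of the commit.

    import numpy as np
    import torch as th
    from gymnasium import spaces
    from metastable_baselines2.common.buffers import BetterRolloutBuffer

    n_steps, n_envs, obs_dim, act_dim = 4, 2, 3, 2
    buffer = BetterRolloutBuffer(
        n_steps,
        spaces.Box(-1, 1, shape=(obs_dim,)),
        spaces.Box(-1, 1, shape=(act_dim,)),
        device="cpu",
        gae_lambda=0.95,
        gamma=0.99,
        n_envs=n_envs,
    )
    for _ in range(n_steps):
        # Stand-in policy distribution: an isotropic Gaussian per environment
        dist = th.distributions.Normal(th.zeros(n_envs, act_dim), th.ones(n_envs, act_dim))
        actions = dist.sample()
        buffer.add(
            np.zeros((n_envs, obs_dim), dtype=np.float32),  # obs
            actions.numpy(),                                 # action
            np.zeros(n_envs, dtype=np.float32),              # reward
            np.zeros(n_envs, dtype=np.float32),              # episode_start
            th.zeros(n_envs),                                # value
            dist.log_prob(actions).sum(-1),                  # log_prob
            dist.mean,                                       # mean of the distribution
            dist.scale,                                      # covariance decomposition (diagonal scale)
        )
    buffer.compute_returns_and_advantage(last_values=th.zeros(n_envs), dones=np.zeros(n_envs))
    for samples in buffer.get(batch_size=4):
        # The stored distribution parameters come back per sample and per action dimension
        print(samples.mean.shape, samples.cov_decomp.shape)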
@@ -8,6 +8,7 @@ from gymnasium import spaces
 
 from stable_baselines3.common.base_class import BaseAlgorithm
 from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from .buffers import BetterDictRolloutBuffer, BetterRolloutBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.policies import ActorCriticPolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -70,6 +71,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         use_sde: bool,
         sde_sample_freq: int,
         use_pca: bool,
+        pca_is: bool,
         rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
         rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         stats_window_size: int = 100,
@@ -85,6 +87,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         assert not (use_sde and use_pca)
         self.use_pca = use_pca
 
+        assert not pca_is or use_pca
+        self.pca_is = pca_is
+
         assert not rollout_buffer_class and not rollout_buffer_kwargs
 
         super().__init__(
@@ -122,9 +127,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
         if self.rollout_buffer_class is None:
             if isinstance(self.observation_space, spaces.Dict):
-                self.rollout_buffer_class = DictRolloutBuffer
+                self.rollout_buffer_class = BetterDictRolloutBuffer
             else:
-                self.rollout_buffer_class = RolloutBuffer
+                self.rollout_buffer_class = BetterRolloutBuffer
 
         self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
@@ -186,7 +191,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             with th.no_grad():
                 # Convert to pytorch tensor or to TensorDict
                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
-                actions, values, log_probs = self.policy(obs_tensor)
+                actions, values, log_probs, distributions = self.policy(obs_tensor, conditioned_log_probs=self.pca_is)
             actions = actions.cpu().numpy()
 
             # Rescale and perform action
@@ -231,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
@@ -254,6 +259,30 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         """
         raise NotImplementedError
 
+    def predict(
+        self,
+        observation: Union[np.ndarray, Dict[str, np.ndarray]],
+        state: Optional[Tuple[np.ndarray, ...]] = None,
+        episode_start: Optional[np.ndarray] = None,
+        deterministic: bool = False,
+        trajectory: th.Tensor = None,
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        """
+        Get the policy action from an observation (and optional hidden state).
+        Includes sugar-coating to handle different observations (e.g. normalizing images).
+
+        :param observation: the input observation
+        :param state: The last hidden states (can be None, used in recurrent policies)
+        :param episode_start: The last masks (can be None, used in recurrent policies)
+            this corresponds to the beginning of episodes,
+            where the hidden states of the RNN must be reset.
+        :param deterministic: Whether or not to return deterministic actions.
+        :param trajectory: Past trajectory. Only required when using PCA.
+        :return: the model's action and the next hidden state
+            (used in recurrent policies)
+        """
+        return self.policy.predict(observation, state, episode_start, deterministic, trajectory=trajectory)
+
     def learn(
         self: SelfOnPolicyAlgorithm,
         total_timesteps: int,
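A hedged sketch of the new ``predict`` override in use; ``model`` and ``past_actions`` are placeholders, and the exact trajectory layout expected by ``PCA_Distribution`` is not specified in this commit, so it is only forwarded unchanged.

    from typing import Optional
    import numpy as np
    import torch as th

    def act(model, obs: np.ndarray, past_actions: Optional[th.Tensor]) -> np.ndarray:
        # model is assumed to be a BetterOnPolicyAlgorithm subclass trained with use_pca=True.
        action, _state = model.predict(obs, deterministic=False, trajectory=past_actions)
        return action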
@@ -34,9 +34,12 @@ from stable_baselines3.common.torch_layers import (
 
 from stable_baselines3.common.policies import ContinuousCritic
 
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import Schedule, RolloutBufferSamples
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
 
+from metastable_projections.projections import BaseProjectionLayer, IdentityProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
+
+
 from .distributions import make_proba_distribution
 from metastable_baselines2.common.pca import PCA_Distribution
 
@@ -395,7 +398,7 @@ class BasePolicy(BaseModel, ABC):
 class ActorCriticPolicy(BasePolicy):
     """
     Policy class for actor-critic algorithms (has both policy and value prediction).
-    Used by A2C, PPO and the likes.
+    Used by A2C, PPO, TRPL and the likes.
 
     :param observation_space: Observation space
     :param action_space: Action space
@@ -445,6 +448,7 @@ class ActorCriticPolicy(BasePolicy):
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         dist_kwargs: Optional[Dict[str, Any]] = {},
+        policy_projection: BaseProjectionLayer = IdentityProjectionLayer(),
     ):
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
@@ -514,6 +518,8 @@ class ActorCriticPolicy(BasePolicy):
         self.use_pca = use_pca
         self.dist_kwargs = dist_kwargs
 
+        self.policy_projection = policy_projection
+
         # Action distribution
         self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, use_pca=use_pca, dist_kwargs=dist_kwargs)
 
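The constructor change above only stores the projection layer; below is a hedged sketch of supplying it through the usual Stable-Baselines3 ``policy_kwargs`` mechanism. The algorithm class is left abstract because this commit does not define one, and the constructor arguments of the non-identity projection layers are deliberately omitted rather than guessed.

    from metastable_projections.projections import IdentityProjectionLayer

    # IdentityProjectionLayer() is the default value used in the diff above and effectively
    # disables the trust region; KLProjectionLayer, WassersteinProjectionLayer or
    # FrobeniusProjectionLayer would be constructed here instead, with whatever arguments
    # metastable_projections expects.
    policy_kwargs = dict(
        policy_projection=IdentityProjectionLayer(),
    )
    # policy_kwargs would then be passed to the on-policy algorithm, which forwards it to
    # ActorCriticPolicy.__init__ (standard Stable-Baselines3 behaviour).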
@@ -624,7 +630,7 @@ class ActorCriticPolicy(BasePolicy):
         # Setup optimizer with initial learning rate
         self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, conditioned_log_probs: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor, Distribution]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -644,9 +650,13 @@ class ActorCriticPolicy(BasePolicy):
         values = self.value_net(latent_vf)
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
-        log_prob = distribution.log_prob(actions)
+        if conditioned_log_probs:
+            assert self.use_pca, 'Cannot calculate conditioned log probs when PCA is disabled.'
+            log_prob = distribution.conditioned_log_prob(actions)
+        else:
+            log_prob = distribution.log_prob(actions)
         actions = actions.reshape((-1, *self.action_space.shape))
-        return actions, values, log_prob
+        return actions, values, log_prob, distribution
 
     def extract_features(self, obs: th.Tensor) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
         """
@@ -691,7 +701,7 @@ class ActorCriticPolicy(BasePolicy):
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -699,19 +709,23 @@ class ActorCriticPolicy(BasePolicy):
         :param deterministic: Whether to use stochastic or deterministic actions
         :return: Taken action according to the policy
         """
-        return self.get_distribution(observation).get_actions(deterministic=deterministic)
+        if self.use_pca:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic, trajectory=trajectory)
+        else:
+            return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
-    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+    def evaluate_actions(self, rollout_data: RolloutBufferSamples, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         """
         Evaluate actions according to the current policy,
         given the observations.
 
-        :param obs: Observation
+        :param rollout_data: Rollout samples (containing the observations and the old distribution parameters)
         :param actions: Actions
         :return: estimated value, log likelihood of taking those actions
             and entropy of the action distribution.
         """
         # Preprocess the observation if needed
+        obs = rollout_data.observations
         features = self.extract_features(obs)
         if self.share_features_extractor:
             latent_pi, latent_vf = self.mlp_extractor(features)
@@ -719,11 +733,13 @@ class ActorCriticPolicy(BasePolicy):
             pi_features, vf_features = features
             latent_pi = self.mlp_extractor.forward_actor(pi_features)
             latent_vf = self.mlp_extractor.forward_critic(vf_features)
-        distribution = self._get_action_dist_from_latent(latent_pi)
+        raw_distribution = self._get_action_dist_from_latent(latent_pi)
+        distribution, old_distribution = self.policy_projection.project_from_rollouts(raw_distribution, rollout_data)
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
         entropy = distribution.entropy()
-        return values, log_prob, entropy
+        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
+        return values, log_prob, entropy, trust_region_loss
 
     def get_distribution(self, obs: th.Tensor) -> Distribution:
         """
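Since ``evaluate_actions`` now takes the rollout samples and returns an extra trust-region term, the training loop has to unpack four values. The sketch below is an assumption about how a PPO-style loss could consume it; ``vf_coef`` and ``ent_coef`` are placeholder coefficients, and any weighting of the trust-region term is assumed to happen inside the projection layer.

    import torch as th

    def surrogate_loss(policy, rollout_data, vf_coef: float = 0.5, ent_coef: float = 0.0) -> th.Tensor:
        # evaluate_actions projects the current distribution against the stored (old) one
        # and returns the projection's trust-region penalty alongside the usual quantities.
        values, log_prob, entropy, trust_region_loss = policy.evaluate_actions(rollout_data, rollout_data.actions)
        ratio = th.exp(log_prob - rollout_data.old_log_prob)
        policy_loss = -(rollout_data.advantages * ratio).mean()
        value_loss = th.nn.functional.mse_loss(rollout_data.returns, values.flatten())
        entropy_loss = -th.mean(entropy) if entropy is not None else -th.mean(-log_prob)
        # The penalty is simply added on top of the usual actor-critic terms.
        return policy_loss + ent_coef * entropy_loss + vf_coef * value_loss + trust_region_loss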
@@ -918,18 +934,18 @@ class Actor(BasePolicy):
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
-        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
+        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, trajectory=trajectory, **kwargs)
 
     def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
-        return self(observation, deterministic)
+    def _predict(self, observation: th.Tensor, deterministic: bool = False, trajectory: th.Tensor = None) -> th.Tensor:
+        return self(observation, deterministic, trajectory=trajectory)
 
 
 class SACPolicy(BasePolicy):
@@ -1018,16 +1034,14 @@ class SACPolicy(BasePolicy):
         }
         self.actor_kwargs = self.net_args.copy()
 
-        sde_kwargs = {
+        self.actor_kwargs.update({
             "use_sde": use_sde,
             "use_pca": use_pca,
             "log_std_init": log_std_init,
             "use_expln": use_expln,
             "clip_mean": clip_mean,
             "dist_kwargs": dist_kwargs,
-        }
-
-        self.actor_kwargs.update(sde_kwargs)
+        })
         self.critic_kwargs = self.net_args.copy()
         self.critic_kwargs.update(
             {