From 4831fbb7e184063bef175dc069856fd322c3a1bc Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 30 Mar 2024 14:42:41 +0100
Subject: [PATCH] Support Full Cov in Buffer

---
 metastable_baselines2/common/buffers.py |  8 +++++++-
 .../common/on_policy_algorithm.py        | 21 +++++++++++----------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py
index 18dbb41..5407688 100644
--- a/metastable_baselines2/common/buffers.py
+++ b/metastable_baselines2/common/buffers.py
@@ -87,7 +87,9 @@ class BetterRolloutBuffer(RolloutBuffer):
         gae_lambda: float = 1,
         gamma: float = 0.99,
         n_envs: int = 1,
+        full_cov: bool = False,
     ):
+        self.full_cov = full_cov
         super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
         self.gae_lambda = gae_lambda
         self.gamma = gamma
@@ -103,7 +105,10 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        if self.full_cov:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim, self.action_dim), dtype=np.float32)
+        else:
+            self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
@@ -156,6 +161,7 @@ class BetterRolloutBuffer(RolloutBuffer):
         mean: th.Tensor,
         cov_decomp: th.Tensor,
     ) -> None:
+        """
         :param obs: Observation
         :param action: Action
diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py
index 2af3175..7fd17ee 100644
--- a/metastable_baselines2/common/on_policy_algorithm.py
+++ b/metastable_baselines2/common/on_policy_algorithm.py
@@ -120,7 +120,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
         if 'dist_kwargs' not in self.policy_kwargs:
             self.policy_kwargs['dist_kwargs'] = {}
-        self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
+        self.policy_kwargs['dist_kwargs']['n_envs'] = self.env.num_envs if hasattr(self.env, 'num_envs') else 1
 
         self.rollout_buffer_class = None
         self.rollout_buffer_kwargs = {}
@@ -135,6 +135,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         else:
             self.rollout_buffer_class = BetterRolloutBuffer
 
+        self.policy = self.policy_class( # pytype:disable=not-instantiable
+            self.observation_space,
+            self.action_space,
+            self.lr_schedule,
+            use_sde=self.use_sde,
+            use_pca=self.use_pca,
+            **self.policy_kwargs # pytype:disable=not-instantiable
+        )
         self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
             self.observation_space,
@@ -143,16 +151,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            full_cov=(self.use_pca and self.policy.action_dist.is_full()),
             **self.rollout_buffer_kwargs,
         )
-        self.policy = self.policy_class( # pytype:disable=not-instantiable
-            self.observation_space,
-            self.action_space,
-            self.lr_schedule,
-            use_sde=self.use_sde,
-            use_pca=self.use_pca,
-            **self.policy_kwargs # pytype:disable=not-instantiable
-        )
         self.policy = self.policy.to(self.device)
 
     def collect_rollouts(
@@ -240,7 +241,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale if hasattr(distributions.distribution, 'scale') else distributions.distribution.scale_tril)
             self._last_obs = new_obs
             self._last_episode_starts = dones
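
Note (not part of the patch): a minimal illustrative sketch of the shapes and dispatch the change introduces. The values of buffer_size, n_envs, and action_dim below are placeholders; the array shapes mirror the new full_cov branch in BetterRolloutBuffer.reset(), and the attribute check mirrors the scale / scale_tril dispatch added in collect_rollouts.

import numpy as np
import torch as th

# Placeholder sizes for illustration only.
buffer_size, n_envs, action_dim = 16, 4, 3
full_cov = True  # mirrors the new `full_cov` flag

if full_cov:
    # one Cholesky factor (action_dim x action_dim) per step and env
    cov_decomps = np.zeros((buffer_size, n_envs, action_dim, action_dim), dtype=np.float32)
else:
    # one per-dimension scale vector per step and env
    cov_decomps = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

# Diagonal Gaussians expose `.scale`; full-covariance Gaussians expose `.scale_tril`.
diag = th.distributions.Normal(th.zeros(n_envs, action_dim), th.ones(n_envs, action_dim))
full = th.distributions.MultivariateNormal(
    th.zeros(n_envs, action_dim),
    scale_tril=th.eye(action_dim).expand(n_envs, action_dim, action_dim),
)

for dist in (diag, full):
    decomp = dist.scale if hasattr(dist, "scale") else dist.scale_tril
    print(type(dist).__name__, tuple(decomp.shape))
# Normal (4, 3)
# MultivariateNormal (4, 3, 3)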