From 5b9f8b028ceebb74efdc6cbb2fc49bff522588b9 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Mon, 22 Jan 2024 19:58:08 +0100
Subject: [PATCH] Minor bug fixes

---
 metastable_baselines2/common/buffers.py             | 4 ++--
 metastable_baselines2/common/on_policy_algorithm.py | 2 +-
 metastable_baselines2/common/policies.py            | 3 ++-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/metastable_baselines2/common/buffers.py b/metastable_baselines2/common/buffers.py
index 3b9aeb6..500167f 100644
--- a/metastable_baselines2/common/buffers.py
+++ b/metastable_baselines2/common/buffers.py
@@ -102,8 +102,8 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.means = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py
index 0320526..a145852 100644
--- a/metastable_baselines2/common/on_policy_algorithm.py
+++ b/metastable_baselines2/common/on_policy_algorithm.py
@@ -236,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value
 
-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
 
             self._last_obs = new_obs
             self._last_episode_starts = dones
diff --git a/metastable_baselines2/common/policies.py b/metastable_baselines2/common/policies.py
index 40db45d..9fba2da 100644
--- a/metastable_baselines2/common/policies.py
+++ b/metastable_baselines2/common/policies.py
@@ -334,6 +334,7 @@ class BasePolicy(BaseModel, ABC):
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
+        trajectory = None,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
@@ -738,7 +739,7 @@ class ActorCriticPolicy(BasePolicy):
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
         entropy = distribution.entropy()
-        trust_region_loss = self.projection.get_trust_region_loss(raw_distribution, old_distribution)
+        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
         return values, log_prob, entropy, trust_region_loss
 
     def get_distribution(self, obs: th.Tensor) -> Distribution: