Minor bug fixes

Dominik Moritz Roth 2024-01-22 19:58:08 +01:00
parent 3d2b7dfc8f
commit 5b9f8b028c
3 changed files with 5 additions and 4 deletions


@@ -102,8 +102,8 @@ class BetterRolloutBuffer(RolloutBuffer):
self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-self.means = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-self.cov_decomps = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.generator_ready = False
super().reset()
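
The shape change above is the substance of the fix: for a Gaussian policy over a multi-dimensional action space, the per-step mean and the covariance decomposition (e.g. the scale of a diagonal Gaussian) are vectors of length action_dim, not scalars, so the buffer needs a third axis. A minimal sketch of why the old (buffer_size, n_envs) layout breaks (the sizes are placeholders, not values from the repository):

import numpy as np

# Placeholder sizes for illustration only.
buffer_size, n_envs, action_dim = 2048, 8, 6

means = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
cov_decomps = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

# One step of a diagonal-Gaussian policy: mean and scale are (n_envs, action_dim) arrays.
step_mean = np.random.randn(n_envs, action_dim).astype(np.float32)
step_scale = np.ones((n_envs, action_dim), dtype=np.float32)

# Fits the new layout; with the old (buffer_size, n_envs) arrays this assignment
# would raise "could not broadcast input array from shape (8, 6) into shape (8,)".
means[0] = step_mean
cov_decomps[0] = step_scale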


@@ -236,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
terminal_value = self.policy.predict_values(terminal_obs)[0]
rewards[idx] += self.gamma * terminal_value
-rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
+rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
self._last_obs = new_obs
self._last_episode_starts = dones
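
The one-line change above reflects how stable-baselines3 structures its distributions: the SB3 Distribution classes (e.g. DiagGaussianDistribution) wrap the underlying torch.distributions object in a .distribution attribute and expose no .mean or .scale of their own, so the old call failed with an AttributeError. A short sketch, assuming the fork keeps the upstream SB3 interface:

import torch as th
from stable_baselines3.common.distributions import DiagGaussianDistribution

# Placeholder sizes for illustration only.
n_envs, action_dim = 4, 3
dist = DiagGaussianDistribution(action_dim)
dist.proba_distribution(mean_actions=th.zeros(n_envs, action_dim), log_std=th.zeros(action_dim))

# The torch.distributions.Normal lives on the wrapper's .distribution attribute;
# dist.mean / dist.scale do not exist on the wrapper itself.
print(dist.distribution.mean.shape)   # torch.Size([4, 3])
print(dist.distribution.scale.shape)  # torch.Size([4, 3])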


@@ -334,6 +334,7 @@ class BasePolicy(BaseModel, ABC):
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
+trajectory = None,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
@@ -738,7 +739,7 @@ class ActorCriticPolicy(BasePolicy):
log_prob = distribution.log_prob(actions)
values = self.value_net(latent_vf)
entropy = distribution.entropy()
-trust_region_loss = self.projection.get_trust_region_loss(raw_distribution, old_distribution)
+trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
return values, log_prob, entropy, trust_region_loss
def get_distribution(self, obs: th.Tensor) -> Distribution:
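
The last hunk only renames the attribute that holds the projection layer (self.projection to self.policy_projection); the get_trust_region_loss(raw_distribution, old_distribution) interface is unchanged. As a rough, generic stand-in for such a layer (not the repository's actual projection, which presumably implements a proper trust-region projection), a KL-based penalty with the same method name could look like:

import torch as th
from torch.distributions import Normal, kl_divergence

class KLTrustRegionPenalty:
    # Generic stand-in; the real policy_projection may use a KL, Wasserstein
    # or Frobenius projection with additional bounds and parameters.
    def __init__(self, coef: float = 1.0):
        self.coef = coef

    def get_trust_region_loss(self, raw_distribution: Normal, old_distribution: Normal) -> th.Tensor:
        # Penalize how far the new (unprojected) policy drifts from the old one.
        kl = kl_divergence(old_distribution, raw_distribution)
        return self.coef * kl.sum(dim=-1).mean()

# Illustrative usage with placeholder tensors (batch of 8, action_dim of 3).
raw_distribution = Normal(th.zeros(8, 3), th.ones(8, 3))
old_distribution = Normal(0.1 * th.ones(8, 3), 0.9 * th.ones(8, 3))
trust_region_loss = KLTrustRegionPenalty(coef=10.0).get_trust_region_loss(raw_distribution, old_distribution)
print(trust_region_loss.item())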