Minor bug fixes: store per-action-dimension Gaussian statistics in the rollout buffer, read mean/scale from the wrapped torch distribution, accept an optional trajectory argument in BasePolicy.predict(), and use the renamed policy_projection attribute for the trust-region loss.
parent 3d2b7dfc8f
commit 5b9f8b028c
@@ -102,8 +102,8 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.means = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
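Note on the hunk above: the old (buffer_size, n_envs) arrays could only hold one scalar per step and environment, but a Gaussian policy over a continuous action space produces a mean vector and, judging from the scale argument passed in the next hunk, a per-dimension scale vector of length action_dim. Below is a minimal, standalone NumPy sketch of what the corrected layout can store; the sizes are hypothetical and this is not the fork's BetterRolloutBuffer code.

import numpy as np

buffer_size, n_envs, action_dim = 4, 2, 3  # hypothetical sizes, not taken from the diff

# Fixed layout: one slot per action dimension for the Gaussian statistics.
means = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
cov_decomps = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

# One rollout step worth of statistics: a mean vector and a scale vector per env.
step_mean = np.random.randn(n_envs, action_dim).astype(np.float32)
step_scale = np.ones((n_envs, action_dim), dtype=np.float32)

pos = 0
means[pos] = step_mean          # only fits with the trailing action_dim axis
cov_decomps[pos] = step_scale
print(means.shape, cov_decomps.shape)  # (4, 2, 3) (4, 2, 3)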
@@ -236,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value

-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
             self._last_obs = new_obs
             self._last_episode_starts = dones

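Note on the hunk above: the fork's distributions object may be a custom class, but assuming it follows stable-baselines3's wrapper pattern, the actual torch.distributions object lives in a .distribution attribute, and the mean/scale tensors live on that wrapped object rather than on the wrapper (which, at least in upstream stable-baselines3, does not expose them itself). A sketch using upstream DiagGaussianDistribution as a stand-in for whatever class the fork actually uses:

import torch as th
from stable_baselines3.common.distributions import DiagGaussianDistribution

batch, action_dim = 5, 3  # hypothetical sizes
dist = DiagGaussianDistribution(action_dim)
mean_actions = th.zeros(batch, action_dim)
log_std = th.zeros(action_dim)
dist = dist.proba_distribution(mean_actions, log_std)

# The wrapper itself has no .mean/.scale; the wrapped torch Normal does,
# which is why the buffer call needs the extra ".distribution" hop.
print(dist.distribution.mean.shape)   # torch.Size([5, 3])
print(dist.distribution.scale.shape)  # torch.Size([5, 3])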
@@ -334,6 +334,7 @@ class BasePolicy(BaseModel, ABC):
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
+        trajectory = None,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
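Note on the hunk above: because the new trajectory argument defaults to None, the signature change is backward compatible. The free function below is only a mock of the updated signature (not the fork's BasePolicy.predict implementation) to show that existing call sites keep working and the new keyword is accepted.

from typing import Optional, Tuple
import numpy as np

def predict(
    observation: np.ndarray,
    state: Optional[Tuple[np.ndarray, ...]] = None,
    episode_start: Optional[np.ndarray] = None,
    deterministic: bool = False,
    trajectory=None,  # new optional argument from this commit; its semantics are defined by the fork
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
    # Placeholder body: a real policy would map the observation to an action here.
    return np.zeros(1, dtype=np.float32), state

obs = np.zeros(3, dtype=np.float32)
predict(obs)                   # existing call sites are unaffected
predict(obs, trajectory=None)  # new keyword is accepted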
@@ -738,7 +739,7 @@ class ActorCriticPolicy(BasePolicy):
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
         entropy = distribution.entropy()
-        trust_region_loss = self.projection.get_trust_region_loss(raw_distribution, old_distribution)
+        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
         return values, log_prob, entropy, trust_region_loss

    def get_distribution(self, obs: th.Tensor) -> Distribution:
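Note on the hunk above: evaluate_actions returns a fourth value, trust_region_loss, computed by the (renamed) policy_projection object. The fork's projection layer computes that term for real; the sketch below is only an illustrative stand-in showing the general shape of such a penalty, a divergence between the current policy Gaussian and the one stored during rollout collection, added on top of the usual surrogate loss.

import torch as th
from torch.distributions import Normal, kl_divergence

# Old (rolled-out) and current policy Gaussians over a 3-dimensional action,
# batch of 4; the numbers are made up.
old_dist = Normal(th.zeros(4, 3), th.ones(4, 3))
new_dist = Normal(th.full((4, 3), 0.1), th.full((4, 3), 0.9))

# Stand-in trust-region penalty: mean KL between the new and old policy.
trust_region_loss = kl_divergence(new_dist, old_dist).sum(dim=-1).mean()

surrogate_loss = th.tensor(0.0)  # placeholder for the usual policy-gradient term
total_loss = surrogate_loss + trust_region_loss
print(float(trust_region_loss))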