Minor bug fixes
parent 3d2b7dfc8f
commit 5b9f8b028c
@@ -102,8 +102,8 @@ class BetterRolloutBuffer(RolloutBuffer):
         self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.means = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
-        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+        self.means = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
+        self.cov_decomps = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
         self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
         self.generator_ready = False
         super().reset()
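The shape change above makes room for per-action-dimension Gaussian parameters: `values`, `log_probs`, and `advantages` stay one scalar per step and environment, while `means` and `cov_decomps` now carry a trailing `action_dim` axis. A minimal numpy sketch of the resulting layout (the concrete sizes below are illustrative, not taken from the repository):

import numpy as np

buffer_size, n_envs, action_dim = 2048, 8, 6  # illustrative sizes only

# One scalar per step and environment:
values = np.zeros((buffer_size, n_envs), dtype=np.float32)

# One entry per action dimension, as in the fixed allocation:
means = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
cov_decomps = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)

# Storing one rollout step of Gaussian parameters:
step = 0
means[step] = np.random.randn(n_envs, action_dim)       # per-dimension means
cov_decomps[step] = np.ones((n_envs, action_dim))        # per-dimension (diagonal) scales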
@@ -236,7 +236,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
                         terminal_value = self.policy.predict_values(terminal_obs)[0]
                     rewards[idx] += self.gamma * terminal_value

-            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.mean, distributions.scale)
+            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs, distributions.distribution.mean, distributions.distribution.scale)
             self._last_obs = new_obs
             self._last_episode_starts = dones

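The fixed call reads the Gaussian parameters from `distributions.distribution` rather than from the wrapper itself. In stable-baselines3, distribution objects such as `DiagGaussianDistribution` hold the underlying `torch.distributions` instance in their `.distribution` attribute, and only that inner object exposes `.mean` and `.scale`. A small sketch of that access pattern, assuming the policy here yields an SB3 `DiagGaussianDistribution` (batch sizes are made up for illustration):

import torch as th
from stable_baselines3.common.distributions import DiagGaussianDistribution

action_dim = 4
dist = DiagGaussianDistribution(action_dim)

mean_actions = th.zeros(2, action_dim)   # fake policy outputs for a batch of 2
log_std = th.zeros(action_dim)
dist = dist.proba_distribution(mean_actions, log_std)

# The SB3 wrapper has no .mean/.scale of its own; they live on the wrapped
# torch.distributions.Normal object:
print(dist.distribution.mean.shape)    # torch.Size([2, 4])
print(dist.distribution.scale.shape)   # torch.Size([2, 4])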
@@ -334,6 +334,7 @@ class BasePolicy(BaseModel, ABC):
         state: Optional[Tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
+        trajectory = None,
     ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
@@ -738,7 +739,7 @@ class ActorCriticPolicy(BasePolicy):
         log_prob = distribution.log_prob(actions)
         values = self.value_net(latent_vf)
         entropy = distribution.entropy()
-        trust_region_loss = self.projection.get_trust_region_loss(raw_distribution, old_distribution)
+        trust_region_loss = self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
         return values, log_prob, entropy, trust_region_loss

     def get_distribution(self, obs: th.Tensor) -> Distribution:
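With the corrected attribute name, the trust-region penalty from `get_trust_region_loss` is returned by `evaluate_actions` alongside `values`, `log_prob`, and `entropy`. How the training step weights that term is not shown in this commit; the sketch below is only one plausible PPO-style combination, and the coefficient name `trust_region_coef` is an assumption rather than this repository's API:

import torch as th

def policy_loss_with_projection(advantages, log_prob, old_log_prob, entropy, trust_region_loss,
                                clip_range=0.2, ent_coef=0.0, trust_region_coef=1.0):
    # Standard PPO clipped surrogate objective
    ratio = th.exp(log_prob - old_log_prob)
    clipped_ratio = th.clamp(ratio, 1 - clip_range, 1 + clip_range)
    pg_loss = -th.min(ratio * advantages, clipped_ratio * advantages).mean()
    # Entropy bonus encourages exploration; the projection term penalises leaving the trust region
    # (trust_region_coef is a hypothetical knob for this sketch).
    return pg_loss - ent_coef * entropy.mean() + trust_region_coef * trust_region_loss.mean()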