_global_steps ctr added
This commit is contained in:
parent
866f863d70
commit
cafc90409f
@ -167,6 +167,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
|
||||
# Different from PPO:
|
||||
self.projection = projection
|
||||
self._global_steps = 0
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
@ -210,6 +211,10 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
approx_kl_divs = []
|
||||
# Do a complete pass on the rollout buffer
|
||||
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||
# This is new compared to PPO.
|
||||
# Calculating the TR-Projections we need to know the step number
|
||||
self._global_steps += 1
|
||||
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action from float to long
|
||||
@ -241,7 +246,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
features = pol.extract_features(rollout_data.observations)
|
||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||
p = pol._get_action_dist_from_latent(latent_pi)
|
||||
# TODO: define b_q and global_step
|
||||
b_q = rollout_data.mean, rollout_data.std
|
||||
proj_p = self.projection(pol, p, b_q, self._global_step)
|
||||
log_prob = proj_p.log_prob(actions)
|
||||
# or log_prob = pol.log_probability(proj_p, actions)
|
||||
@ -466,8 +471,9 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
0]
|
||||
rewards[idx] += self.gamma * terminal_value
|
||||
|
||||
# TODO: calc mean + std
|
||||
rollout_buffer.add(self._last_obs, actions, rewards,
|
||||
self._last_episode_starts, values, log_probs)
|
||||
self._last_episode_starts, values, log_probs, mean, std)
|
||||
self._last_obs = new_obs
|
||||
self._last_episode_starts = dones
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user