_global_steps counter added

Dominik Moritz Roth 2022-06-25 15:05:54 +02:00
parent 866f863d70
commit cafc90409f


@@ -167,6 +167,7 @@ class TRL_PG(OnPolicyAlgorithm):
         # Different from PPO:
         self.projection = projection
+        self._global_steps = 0

         if _init_setup_model:
             self._setup_model()
@@ -210,6 +211,10 @@ class TRL_PG(OnPolicyAlgorithm):
         approx_kl_divs = []
         # Do a complete pass on the rollout buffer
         for rollout_data in self.rollout_buffer.get(self.batch_size):
+            # This is new compared to PPO:
+            # to calculate the TR projections we need to know the step number.
+            self._global_steps += 1
+
             actions = rollout_data.actions
             if isinstance(self.action_space, spaces.Discrete):
                 # Convert discrete action from float to long
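
Why the projection needs the step number: in trust-region-layers-style projections, the bounds are typically annealed by a schedule driven by a global step counter. Below is a minimal sketch of such a step-dependent projection matching the call signature used in this commit; the class name, the linear schedule, and the interpolation rule are illustrative assumptions, not this repo's implementation:

    import torch
    from torch.distributions import Normal, kl_divergence

    class LinearScheduleProjection:
        # Hypothetical projection: the KL bound is annealed with the global
        # step; if the new policy p drifts past the bound, its parameters
        # are pulled back towards the behavior policy q.
        def __init__(self, initial_bound=0.1, total_steps=1_000_000):
            self.initial_bound = initial_bound
            self.total_steps = total_steps

        def __call__(self, policy, p, q, global_step):
            frac = min(global_step / self.total_steps, 1.0)
            bound = self.initial_bound * (1.0 - frac)  # tightens over time
            kl = kl_divergence(q, p).sum(-1).mean()
            if kl <= bound:
                return p  # already inside the trust region
            # Approximate projection: interpolate towards q so the
            # projected distribution lands near the bound.
            alpha = torch.sqrt(bound / (kl + 1e-8))
            return Normal(q.loc + alpha * (p.loc - q.loc),
                          q.scale + alpha * (p.scale - q.scale))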
@@ -241,7 +246,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                # TODO: define b_q and global_step
+                b_q = rollout_data.mean, rollout_data.std
                 proj_p = self.projection(pol, p, b_q, self._global_steps)
                 log_prob = proj_p.log_prob(actions)
                 # or log_prob = pol.log_probability(proj_p, actions)
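
For context on how the projected distribution is consumed: b_q carries the behavior policy's Gaussian parameters recorded at rollout time, and proj_p simply replaces the raw distribution in the usual clipped surrogate. A sketch under those assumptions (the function and argument names are illustrative; here the old distribution is rebuilt from b_q before the projection call, whereas the commit passes the raw tuple):

    import torch
    from torch.distributions import Normal

    def surrogate_loss(projection, pol, p, b_q, step, actions,
                       old_log_prob, advantages, clip_range=0.2):
        old_mean, old_std = b_q        # stored behavior-policy parameters
        q = Normal(old_mean, old_std)  # distribution the projection anchors to
        proj_p = projection(pol, p, q, step)
        log_prob = proj_p.log_prob(actions).sum(-1)
        ratio = torch.exp(log_prob - old_log_prob)
        # Standard PPO clipping, applied to the projected distribution.
        return -torch.min(ratio * advantages,
                          torch.clamp(ratio, 1 - clip_range,
                                      1 + clip_range) * advantages).mean()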
@@ -466,8 +471,9 @@ class TRL_PG(OnPolicyAlgorithm):
                                 0]
                         rewards[idx] += self.gamma * terminal_value

-            rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs)
+            # TODO: calc mean + std
+            rollout_buffer.add(self._last_obs, actions, rewards,
+                               self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
             self._last_episode_starts = dones
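
The extended add() call implies a rollout buffer that also records the behavior policy's mean and std per step, so that b_q can be rebuilt during training (see rollout_data.mean / rollout_data.std above). A minimal sketch of such a buffer on top of stable-baselines3's RolloutBuffer; the class name is an assumption, and get()/_get_samples() would still need to be extended to expose the two new fields:

    import numpy as np
    from stable_baselines3.common.buffers import RolloutBuffer

    class GaussianRolloutBuffer(RolloutBuffer):
        # Hypothetical buffer that stores the behavior policy's Gaussian
        # parameters alongside the usual PPO quantities.
        def reset(self):
            super().reset()
            shape = (self.buffer_size, self.n_envs, self.action_dim)
            self.means = np.zeros(shape, dtype=np.float32)
            self.stds = np.ones(shape, dtype=np.float32)

        def add(self, obs, action, reward, episode_start, value, log_prob,
                mean, std):
            # Store at the current write position before the parent
            # advances self.pos; mirrors how SB3 converts tensors itself.
            self.means[self.pos] = mean.clone().cpu().numpy()
            self.stds[self.pos] = std.clone().cpu().numpy()
            super().add(obs, action, reward, episode_start, value, log_prob)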