From cafc90409f950c5ed01a25058a7ce54fd2b4b01f Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 25 Jun 2022 15:05:54 +0200
Subject: [PATCH] _global_steps counter added

---
 sb3_trl/trl_pg/trl_pg.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index f26bea3..37d86b2 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -167,6 +167,7 @@ class TRL_PG(OnPolicyAlgorithm):
 
         # Different from PPO:
         self.projection = projection
+        self._global_steps = 0
 
         if _init_setup_model:
             self._setup_model()
@@ -210,6 +211,10 @@ class TRL_PG(OnPolicyAlgorithm):
             approx_kl_divs = []
             # Do a complete pass on the rollout buffer
             for rollout_data in self.rollout_buffer.get(self.batch_size):
+                # This is new compared to PPO:
+                # to calculate the TR projections we need to know the step number.
+                self._global_steps += 1
+
                 actions = rollout_data.actions
                 if isinstance(self.action_space, spaces.Discrete):
                     # Convert discrete action from float to long
@@ -241,7 +246,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                # TODO: define b_q and global_step
-                proj_p = self.projection(pol, p, b_q, self._global_step)
+                b_q = (rollout_data.mean, rollout_data.std)
+                proj_p = self.projection(pol, p, b_q, self._global_steps)
                 log_prob = proj_p.log_prob(actions)
                 # or log_prob = pol.log_probability(proj_p, actions)
@@ -466,8 +471,9 @@ class TRL_PG(OnPolicyAlgorithm):
                             0]
                     rewards[idx] += self.gamma * terminal_value
 
+            # TODO: calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs)
+                               self._last_episode_starts, values, log_probs, mean, std)
             self._last_obs = new_obs
             self._last_episode_starts = dones
 
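
Note on the remaining "TODO: calc mean + std": the extended rollout_buffer.add call passes mean and std, but the patch does not yet compute them in collect_rollouts. A minimal, hypothetical sketch of how the behavior policy's Gaussian parameters can be read off a stable-baselines3 ActorCriticPolicy (written standalone here for clarity; inside collect_rollouts the same two calls would use self.policy and self._last_obs):

    import torch as th
    from stable_baselines3 import PPO
    from stable_baselines3.common.utils import obs_as_tensor

    # Any ActorCriticPolicy works; PPO is used here only to build one.
    model = PPO("MlpPolicy", "Pendulum-v1")
    obs = model.env.reset()
    with th.no_grad():
        # For a diagonal Gaussian policy, the SB3 distribution object wraps a
        # torch.distributions.Normal that exposes mean and stddev directly.
        dist = model.policy.get_distribution(obs_as_tensor(obs, model.device))
        mean = dist.distribution.mean    # shape: (n_envs, action_dim)
        std = dist.distribution.stddev   # shape: (n_envs, action_dim)
    print(mean.shape, std.shape)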
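
Relatedly, rollout_data.mean and rollout_data.std (used above to build b_q) and the widened add() signature both presuppose a rollout buffer that stores the old policy's mean and std, which is not part of this patch either. A hypothetical sketch of such an extension, assuming SB3's RolloutBuffer internals (reset, swap_and_flatten, _get_samples); all class and field names are illustrative:

    from typing import Generator, NamedTuple, Optional

    import numpy as np
    import torch as th
    from stable_baselines3.common.buffers import RolloutBuffer
    from stable_baselines3.common.vec_env import VecNormalize


    class TRLRolloutBufferSamples(NamedTuple):
        # Same fields as SB3's RolloutBufferSamples, plus mean and std.
        observations: th.Tensor
        actions: th.Tensor
        old_values: th.Tensor
        old_log_prob: th.Tensor
        advantages: th.Tensor
        returns: th.Tensor
        mean: th.Tensor
        std: th.Tensor


    class TRLRolloutBuffer(RolloutBuffer):
        """RolloutBuffer that also records the behavior policy's mean/std."""

        def reset(self) -> None:
            self.means = np.zeros(
                (self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
            self.stds = np.zeros(
                (self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
            super().reset()

        def add(self, obs, action, reward, episode_start, value, log_prob,
                mean, std) -> None:
            # Store first: super().add() advances self.pos.
            # Assumes mean/std are th.Tensors of shape (n_envs, action_dim).
            self.means[self.pos] = mean.cpu().numpy()
            self.stds[self.pos] = std.cpu().numpy()
            super().add(obs, action, reward, episode_start, value, log_prob)

        def get(self, batch_size: Optional[int] = None
                ) -> Generator[TRLRolloutBufferSamples, None, None]:
            # Flatten (buffer_size, n_envs, ...) -> (buffer_size * n_envs, ...)
            # once, mirroring what the parent get() does for its own tensors.
            if not self.generator_ready:
                self.means = self.swap_and_flatten(self.means)
                self.stds = self.swap_and_flatten(self.stds)
            yield from super().get(batch_size)

        def _get_samples(self, batch_inds: np.ndarray,
                         env: Optional[VecNormalize] = None
                         ) -> TRLRolloutBufferSamples:
            base = super()._get_samples(batch_inds, env)
            return TRLRolloutBufferSamples(
                *base,
                mean=self.to_torch(self.means[batch_inds]),
                std=self.to_torch(self.stds[batch_inds]),
            )

With a buffer along these lines, train() can build b_q from rollout_data.mean and rollout_data.std without re-evaluating the old policy.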