diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py
index f26bea3..37d86b2 100644
--- a/sb3_trl/trl_pg/trl_pg.py
+++ b/sb3_trl/trl_pg/trl_pg.py
@@ -167,6 +167,7 @@ class TRL_PG(OnPolicyAlgorithm):
 
         # Different from PPO:
         self.projection = projection
+        self._global_steps = 0
 
         if _init_setup_model:
             self._setup_model()
@@ -210,6 +211,10 @@ class TRL_PG(OnPolicyAlgorithm):
             approx_kl_divs = []
             # Do a complete pass on the rollout buffer
             for rollout_data in self.rollout_buffer.get(self.batch_size):
+                # This is new compared to PPO:
+                # to calculate the TR projections we need to know the step number.
+                self._global_steps += 1
+
                 actions = rollout_data.actions
                 if isinstance(self.action_space, spaces.Discrete):
                     # Convert discrete action from float to long
@@ -241,7 +246,7 @@ class TRL_PG(OnPolicyAlgorithm):
                 features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                # TODO: define b_q and global_step
+                b_q = rollout_data.mean, rollout_data.std
                 proj_p = self.projection(pol, p, b_q, self._global_steps)
                 log_prob = proj_p.log_prob(actions)
                 # or log_prob = pol.log_probability(proj_p, actions)
@@ -466,8 +471,9 @@ class TRL_PG(OnPolicyAlgorithm):
                         0]
                     rewards[idx] += self.gamma * terminal_value
 
+            # TODO: calc mean + std
             rollout_buffer.add(self._last_obs, actions, rewards,
-                               self._last_episode_starts, values, log_probs)
+                               self._last_episode_starts, values, log_probs, mean, std)
 
             self._last_obs = new_obs
             self._last_episode_starts = dones
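
Note that the last hunk still leaves the "# TODO: calc mean + std" open, so mean and std are undefined at the rollout_buffer.add call. Below is a minimal sketch of how the behaviour policy's Gaussian parameters could be obtained during rollout collection, assuming an SB3 ActorCriticPolicy with a diagonal Gaussian action distribution; the helper name get_dist_params is hypothetical and simply mirrors the calls already used in the training loop above.

import torch as th

def get_dist_params(policy, obs_tensor):
    # Hypothetical helper (not part of this patch): query the behaviour
    # policy's Gaussian parameters so they can be stored in the rollout
    # buffer and later read back as rollout_data.mean / rollout_data.std.
    with th.no_grad():
        features = policy.extract_features(obs_tensor)
        latent_pi, _ = policy.mlp_extractor(features)
        dist = policy._get_action_dist_from_latent(latent_pi)
    # For a diagonal Gaussian, SB3 wraps a torch.distributions.Normal in
    # dist.distribution; mean and stddev are per-action tensors.
    return dist.distribution.mean, dist.distribution.stddev

With such a helper, mean, std = get_dist_params(self.policy, obs_tensor) could be called next to the existing forward pass in collect_rollouts and the two tensors passed to the extended rollout_buffer.add(..., mean, std); the rollout buffer itself would also need matching mean/std fields for rollout_data.mean and rollout_data.std to exist.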