From 25316ec0b84797d0b0bcffdf77eea56173ea68b6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 25 Jun 2022 13:58:58 +0200 Subject: [PATCH] Incremental progress at implementing trl_pg --- sb3_trl/trl_pg/trl_pg.py | 56 +++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py index 41405b7..99932e3 100644 --- a/sb3_trl/trl_pg/trl_pg.py +++ b/sb3_trl/trl_pg/trl_pg.py @@ -61,7 +61,7 @@ class TRL_PG(OnPolicyAlgorithm): Setting it to auto, the code will be run on the GPU if possible. :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - #TODO: Add new params to doc + # TODO: Add new params to doc policy_aliases: Dict[str, Type[BasePolicy]] = { "MlpPolicy": ActorCriticPolicy, @@ -97,7 +97,7 @@ class TRL_PG(OnPolicyAlgorithm): # Different from PPO: #projection: BaseProjectionLayer = None, - projection = None, + projection=None, _init_setup_model: bool = True, ): @@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm): clip_range = self.clip_range(self._current_progress_remaining) # Optional: clip range for the value function if self.clip_range_vf is not None: - clip_range_vf = self.clip_range_vf(self._current_progress_remaining) + clip_range_vf = self.clip_range_vf( + self._current_progress_remaining) surrogate_losses = [] entropy_losses = [] @@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm): # here we go: pol = self.policy - feat = pol.extract_features(rollout_data.observations) + features = pol.extract_features(rollout_data.observations) latent_pi, latent_vf = pol.mlp_extractor(features) p = pol._get_action_dist_from_latent(latent_pi) - proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step - log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions) + # TODO: define b_q and global_step + proj_p = self.projection(pol, p, b_q, self._global_step) + log_prob = proj_p.log_prob(actions) + # or log_prob = pol.log_probability(proj_p, actions) values = self.value_net(latent_vf) - entropy = proj_p.entropy() # or not... - + entropy = proj_p.entropy() # or not... values = values.flatten() # Normalize advantage advantages = rollout_data.advantages if self.normalize_advantage: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + advantages = (advantages - advantages.mean() + ) / (advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) @@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm): # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' # clipped surrogate loss surrogate_loss_1 = advantages * ratio - surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) - surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean() + surrogate_loss_2 = advantages * \ + th.clamp(ratio, 1 - clip_range, 1 + clip_range) + surrogate_loss = - \ + th.min(surrogate_loss_1, surrogate_loss_2).mean() surrogate_losses.append(surrogate_loss.item()) - clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() + clip_fraction = th.mean( + (th.abs(ratio - 1) > clip_range).float()).item() clip_fractions.append(clip_fraction) if self.clip_range_vf is None: @@ -285,8 +291,9 @@ class TRL_PG(OnPolicyAlgorithm): entropy_losses.append(entropy_loss.item()) # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss - #trust_region_loss = self.projection.get_trust_region_loss()#TODO: params - trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement + # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params + trust_region_loss = th.zeros( + 1, device=entropy_loss.device) # TODO: Implement trust_region_losses.append(trust_region_loss.item()) @@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm): # and Schulman blog: http://joschu.net/blog/kl-approx.html with th.no_grad(): log_ratio = log_prob - rollout_data.old_log_prob - approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() + approx_kl_div = th.mean( + (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_divs.append(approx_kl_div) if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: continue_training = False if self.verbose >= 1: - print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") + print( + f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") break # Optimization step self.policy.optimizer.zero_grad() loss.backward() # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + th.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() if not continue_training: break self._n_updates += self.n_epochs - explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + explained_var = explained_variance( + self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs self.logger.record("train/surrogate_loss", np.mean(surrogate_losses)) self.logger.record("train/entropy_loss", np.mean(entropy_losses)) - self.logger.record("train/trust_region_loss", np.mean(trust_region_losses)) + self.logger.record("train/trust_region_loss", + np.mean(trust_region_losses)) self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) @@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm): self.logger.record("train/loss", loss.item()) self.logger.record("train/explained_variance", explained_var) if hasattr(self.policy, "log_std"): - self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + self.logger.record( + "train/std", th.exp(self.policy.log_std).mean().item()) - self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/n_updates", + self._n_updates, exclude="tensorboard") self.logger.record("train/clip_range", clip_range) if self.clip_range_vf is not None: self.logger.record("train/clip_range_vf", clip_range_vf)