Incremental progress on implementing trl_pg
parent 941b7347f1
commit 25316ec0b8
@@ -61,7 +61,7 @@ class TRL_PG(OnPolicyAlgorithm):
         Setting it to auto, the code will be run on the GPU if possible.
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
-    #TODO: Add new params to doc
+    # TODO: Add new params to doc
 
     policy_aliases: Dict[str, Type[BasePolicy]] = {
         "MlpPolicy": ActorCriticPolicy,
@@ -97,7 +97,7 @@ class TRL_PG(OnPolicyAlgorithm):
 
         # Different from PPO:
         #projection: BaseProjectionLayer = None,
-        projection = None,
+        projection=None,
 
         _init_setup_model: bool = True,
     ):
@@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm):
         clip_range = self.clip_range(self._current_progress_remaining)
         # Optional: clip range for the value function
         if self.clip_range_vf is not None:
-            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+            clip_range_vf = self.clip_range_vf(
+                self._current_progress_remaining)
 
         surrogate_losses = []
         entropy_losses = []
@@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm):
 
                 # here we go:
                 pol = self.policy
-                feat = pol.extract_features(rollout_data.observations)
+                features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step
-                log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions)
+                # TODO: define b_q and global_step
+                proj_p = self.projection(pol, p, b_q, self._global_step)
+                log_prob = proj_p.log_prob(actions)
+                # or log_prob = pol.log_probability(proj_p, actions)
                 values = self.value_net(latent_vf)
-                entropy = proj_p.entropy() # or not...
-
+                entropy = proj_p.entropy()  # or not...
 
                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
                 if self.normalize_advantage:
-                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+                    advantages = (advantages - advantages.mean()
+                                  ) / (advantages.std() + 1e-8)
 
                 # ratio between old and new policy, should be one at the first iteration
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)
@@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                 # clipped surrogate loss
                 surrogate_loss_1 = advantages * ratio
-                surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
+                surrogate_loss_2 = advantages * \
+                    th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                surrogate_loss = - \
+                    th.min(surrogate_loss_1, surrogate_loss_2).mean()
 
                 surrogate_losses.append(surrogate_loss.item())
 
-                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
+                clip_fraction = th.mean(
+                    (th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)
 
                 if self.clip_range_vf is None:
@@ -285,8 +291,9 @@ class TRL_PG(OnPolicyAlgorithm):
                 entropy_losses.append(entropy_loss.item())
 
                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
-                #trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
-                trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement
+                # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
+                trust_region_loss = th.zeros(
+                    1, device=entropy_loss.device)  # TODO: Implement
 
                 trust_region_losses.append(trust_region_loss.item())
 
@@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm):
                 # and Schulman blog: http://joschu.net/blog/kl-approx.html
                 with th.no_grad():
                     log_ratio = log_prob - rollout_data.old_log_prob
-                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+                    approx_kl_div = th.mean(
+                        (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                     approx_kl_divs.append(approx_kl_div)
 
                 if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                     continue_training = False
                     if self.verbose >= 1:
-                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
+                        print(
+                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                     break
 
                 # Optimization step
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                th.nn.utils.clip_grad_norm_(
+                    self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
 
             if not continue_training:
                 break
 
         self._n_updates += self.n_epochs
-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+        explained_var = explained_variance(
+            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
 
         # Logs
         self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
         self.logger.record("train/entropy_loss", np.mean(entropy_losses))
-        self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
+        self.logger.record("train/trust_region_loss",
+                           np.mean(trust_region_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
@@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm):
         self.logger.record("train/loss", loss.item())
         self.logger.record("train/explained_variance", explained_var)
         if hasattr(self.policy, "log_std"):
-            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+            self.logger.record(
+                "train/std", th.exp(self.policy.log_std).mean().item())
 
-        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        self.logger.record("train/n_updates",
+                           self._n_updates, exclude="tensorboard")
         self.logger.record("train/clip_range", clip_range)
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
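The open TODOs in this diff (defining b_q and self._global_step, and replacing the zeroed trust_region_loss) all hang off the commented-out BaseProjectionLayer. Below is a minimal sketch of how that wiring could look, assuming a projection layer that is callable as projection(policy, p, q, step) and exposes get_trust_region_loss(policy, p, proj_p), as the commented-out calls suggest; the b_q tuple and the old_means/old_stds buffer fields are likewise assumptions, not existing API:

# Sketch only -- hypothetical wiring for the TODOs above, not the committed code.
def evaluate_with_projection(pol, projection, rollout_data, actions, global_step):
    features = pol.extract_features(rollout_data.observations)
    latent_pi, latent_vf = pol.mlp_extractor(features)
    p = pol._get_action_dist_from_latent(latent_pi)

    # b_q: parameters of the old (behaviour) policy captured at rollout time.
    # old_means/old_stds are assumed buffer fields; how they get stored is still TODO.
    b_q = (rollout_data.old_means, rollout_data.old_stds)

    proj_p = projection(pol, p, b_q, global_step)  # project p back into the trust region
    log_prob = proj_p.log_prob(actions)
    values = pol.value_net(latent_vf).flatten()
    entropy = proj_p.entropy()

    # Would replace the th.zeros(...) placeholder once the projection layer is implemented.
    trust_region_loss = projection.get_trust_region_loss(pol, p, proj_p)
    return log_prob, values, entropy, trust_region_loss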