SAC is back; with SDC; without Projections
This commit is contained in:
		
							parent
							
								
									5f32435751
								
							
						
					
					
						commit
						7b667e9650
					
				| @ -11,7 +11,8 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm | ||||
| from stable_baselines3.common.policies import BasePolicy | ||||
| from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule | ||||
| from stable_baselines3.common.utils import polyak_update | ||||
| from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy | ||||
| 
 | ||||
| from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy | ||||
| 
 | ||||
| from ..misc.distTools import new_dist_like | ||||
| 
 | ||||
| @ -28,7 +29,7 @@ LOG_STD_MAX = 2 | ||||
| LOG_STD_MIN = -20 | ||||
| 
 | ||||
| 
 | ||||
| class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
| class SAC(OffPolicyAlgorithm): | ||||
|     """ | ||||
|     Trust Region Layers (TRL) based on SAC (Soft Actor Critic) | ||||
|     This implementation is almost a 1:1-copy of the sb3-code for SAC. | ||||
| @ -127,14 +128,15 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
|         device: Union[th.device, str] = "auto", | ||||
| 
 | ||||
|         # Different from SAC: | ||||
|         projection: BaseProjectionLayer = KLProjectionLayer(), | ||||
|         # projection: BaseProjectionLayer = BaseProjectionLayer(), | ||||
|         projection=None, | ||||
| 
 | ||||
|         _init_setup_model: bool = True, | ||||
|     ): | ||||
| 
 | ||||
|         super().__init__( | ||||
|             policy, | ||||
|             env, | ||||
|             None,  # PolicyBase | ||||
|             learning_rate, | ||||
|             buffer_size, | ||||
|             learning_starts, | ||||
| @ -160,8 +162,6 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
|             support_multi_env=True, | ||||
|         ) | ||||
| 
 | ||||
|         raise Exception('TRL_SAC is not yet implemented') | ||||
| 
 | ||||
|         self.target_entropy = target_entropy | ||||
|         self.log_ent_coef = None  # type: Optional[th.Tensor] | ||||
|         # Entropy coefficient / Entropy temperature | ||||
| @ -171,8 +171,13 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
|         self.ent_coef_optimizer = None | ||||
| 
 | ||||
|         # Different from SAC: | ||||
|         self.projection = projection | ||||
|         # self.projection = projection | ||||
|         self._global_steps = 0 | ||||
|         self.n_steps = buffer_size | ||||
|         self.gae_lambda = False | ||||
| 
 | ||||
|         if projection != None: | ||||
|             print('[!] An projection was supplied! Will be ignored!') | ||||
| 
 | ||||
|         if _init_setup_model: | ||||
|             self._setup_model() | ||||
| @ -255,30 +260,32 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
|             act = self.actor | ||||
|             features = act.extract_features(replay_data.observations) | ||||
|             latent_pi = act.latent_pi(features) | ||||
|             mean_actions = act.mu(latent_pi) | ||||
|             mean_actions = act.mu_net(latent_pi) | ||||
| 
 | ||||
|             # TODO: Allow contextual covariance with sde | ||||
|             if self.use_sde: | ||||
|                 log_std = act.log_std | ||||
|                 chol = act.chol | ||||
|             else: | ||||
|                 # Unstructured exploration (Original implementation) | ||||
|                 log_std = act.log_std(latent_pi) | ||||
|                 chol = act.chol_net(latent_pi) | ||||
|                 # Original Implementation to cap the standard deviation | ||||
|                 log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) | ||||
|                 chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX) | ||||
|                 act.chol = chol | ||||
| 
 | ||||
|             act_dist = self.action_dist | ||||
|             act_dist = self.actor.action_dist | ||||
|             # internal A | ||||
|             if self.use_sde: | ||||
|                 actions_pi = self.actions_from_params( | ||||
|                     mean_actions, log_std, latent_pi)  # latent_pi = latent_sde | ||||
|                     mean_actions, chol, latent_pi)  # latent_pi = latent_sde | ||||
|             else: | ||||
|                 actions_pi = act_dist.actions_from_params( | ||||
|                     mean_actions, log_std) | ||||
|                     mean_actions, chol) | ||||
| 
 | ||||
|             p_dist = self.action_dist.distribution | ||||
|             q_dist = new_dist_like( | ||||
|                 p_dist, replay_data.means, replay_data.stds) | ||||
|             proj_p = self.projection(p_dist, q_dist, self._global_steps) | ||||
|             p_dist = act_dist.distribution | ||||
|             # q_dist = new_dist_like( | ||||
|             #    p_dist, replay_data.means, replay_data.stds) | ||||
|             #proj_p = self.projection(p_dist, q_dist, self._global_steps) | ||||
|             proj_p = p_dist | ||||
|             log_prob = proj_p.log_prob(actions_pi) | ||||
|             log_prob = log_prob.reshape(-1, 1) | ||||
| 
 | ||||
| @ -368,6 +375,23 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm): | ||||
|         if len(ent_coef_losses) > 0: | ||||
|             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) | ||||
| 
 | ||||
|         pol = self.policy.actor | ||||
| 
 | ||||
|         if hasattr(pol, "log_std"): | ||||
|             self.logger.record( | ||||
|                 "train/std", th.exp(pol.log_std).mean().item()) | ||||
|         elif hasattr(pol, "chol"): | ||||
|             chol = pol.chol | ||||
|             if len(chol.shape) == 1: | ||||
|                 self.logger.record( | ||||
|                     "train/std", th.mean(chol).mean().item()) | ||||
|             elif len(chol.shape) == 2: | ||||
|                 self.logger.record( | ||||
|                     "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item()) | ||||
|             else: | ||||
|                 self.logger.record( | ||||
|                     "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item()) | ||||
| 
 | ||||
|     def learn( | ||||
|         self, | ||||
|         total_timesteps: int, | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user