diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py
index ee6cd1b..6244ef1 100644
--- a/metastable_baselines/sac/sac.py
+++ b/metastable_baselines/sac/sac.py
@@ -11,7 +11,8 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+
+from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
 
 from ..misc.distTools import new_dist_like
 
@@ -28,7 +29,7 @@ LOG_STD_MAX = 2
 LOG_STD_MIN = -20
 
 
-class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
+class SAC(OffPolicyAlgorithm):
     """
     Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
     This implementation is almost a 1:1 copy of the sb3 code for SAC.
@@ -127,14 +128,15 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from SAC:
-        projection: BaseProjectionLayer = KLProjectionLayer(),
+        # projection: BaseProjectionLayer = BaseProjectionLayer(),
+        projection=None,
 
         _init_setup_model: bool = True,
     ):
-
         super().__init__(
             policy,
             env,
+            None,  # PolicyBase
             learning_rate,
             buffer_size,
             learning_starts,
@@ -160,8 +162,6 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             support_multi_env=True,
         )
 
-        raise Exception('TRL_SAC is not yet implemented')
-
         self.target_entropy = target_entropy
         self.log_ent_coef = None  # type: Optional[th.Tensor]
         # Entropy coefficient / Entropy temperature
@@ -171,8 +171,13 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         self.ent_coef_optimizer = None
 
         # Different from SAC:
-        self.projection = projection
+        # self.projection = projection
         self._global_steps = 0
+        self.n_steps = buffer_size
+        self.gae_lambda = False
+
+        if projection is not None:
+            print('[!] A projection was supplied! It will be ignored!')
 
         if _init_setup_model:
             self._setup_model()
@@ -255,30 +260,32 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             act = self.actor
             features = act.extract_features(replay_data.observations)
             latent_pi = act.latent_pi(features)
-            mean_actions = act.mu(latent_pi)
+            mean_actions = act.mu_net(latent_pi)
 
             # TODO: Allow contextual covariance with sde
             if self.use_sde:
-                log_std = act.log_std
+                chol = act.chol
             else:
                 # Unstructured exploration (Original implementation)
-                log_std = act.log_std(latent_pi)
+                chol = act.chol_net(latent_pi)
                 # Original Implementation to cap the standard deviation
-                log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
+                chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+            act.chol = chol
 
-            act_dist = self.action_dist
+            act_dist = self.actor.action_dist  # internal A
             if self.use_sde:
                 actions_pi = self.actions_from_params(
-                    mean_actions, log_std, latent_pi)  # latent_pi = latent_sde
+                    mean_actions, chol, latent_pi)  # latent_pi = latent_sde
             else:
                 actions_pi = act_dist.actions_from_params(
-                    mean_actions, log_std)
+                    mean_actions, chol)
 
-            p_dist = self.action_dist.distribution
-            q_dist = new_dist_like(
-                p_dist, replay_data.means, replay_data.stds)
-            proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            p_dist = act_dist.distribution
+            # q_dist = new_dist_like(
+            #     p_dist, replay_data.means, replay_data.stds)
+            # proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            proj_p = p_dist
 
             log_prob = proj_p.log_prob(actions_pi)
             log_prob = log_prob.reshape(-1, 1)
 
@@ -368,6 +375,23 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
 
+        pol = self.policy.actor
+
+        if hasattr(pol, "log_std"):
+            self.logger.record(
+                "train/std", th.exp(pol.log_std).mean().item())
+        elif hasattr(pol, "chol"):
+            chol = pol.chol
+            if len(chol.shape) == 1:
+                self.logger.record(
+                    "train/std", th.mean(chol).mean().item())
+            elif len(chol.shape) == 2:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
+            else:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item())
+
     def learn(
         self,
         total_timesteps: int,
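For context, a minimal standalone sketch of what the new `train/std` logging reduces a Cholesky factor to. Two caveats: the covariance convention is an assumption on my part (the diff's `th.diagonal(chol.T @ chol, ...)` implies cov = cholᵀ · chol, so the per-dimension standard deviations are the square roots of that product's diagonal), and the `scalar_std` helper is hypothetical, for illustration only — the patch inlines this logic directly in `train()`.

```python
import torch as th


def scalar_std(chol: th.Tensor) -> float:
    """Hypothetical helper: collapse a (possibly batched) Cholesky factor
    to one scalar std for logging, mirroring the diff's three branches.
    Assumes the convention cov = chol.T @ chol."""
    if chol.dim() == 1:
        # Diagonal parameterization: the entries are already stds.
        return chol.mean().item()
    if chol.dim() == 2:
        # Single full factor: stds are the sqrt of the covariance diagonal.
        cov_diag = th.diagonal(chol.T @ chol, dim1=-2, dim2=-1)
        return th.sqrt(cov_diag).mean().item()
    # Batched factors: .mT transposes only the last two dimensions.
    cov_diag = th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1)
    return th.sqrt(cov_diag).mean().item()


# Example: an isotropic factor 0.5 * I yields a scalar std of 0.5.
print(scalar_std(0.5 * th.eye(3)))  # -> 0.5
```

The 2-D and N-D cases are split only because plain `.T` is deprecated for tensors with more than two dimensions; `.mT` handles the batched case, which is presumably why the diff uses both.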