From ee4a0eed56d5545ca753598bad9a41225d235804 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sat, 3 Sep 2022 13:08:31 +0200
Subject: [PATCH] Fixed SAC+SDE+SDC bugs

---
 .../distributions/distributions.py   |  9 +++----
 metastable_baselines/sac/policies.py |  9 +++----
 metastable_baselines/sac/sac.py      |  7 +++---
 test.py                              | 24 +++++++++----------
 4 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index 75336f9..cfcf149 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -320,12 +320,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.gaussian_actions = mode
         return self.prob_squashing_type.apply(mode)
 
-    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
+    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_sde=None) -> th.Tensor:
         # Update the proba distribution
-        self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
+        self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde=None) -> Tuple[th.Tensor, th.Tensor]:
         """
         Compute the log probability of taking an action
         given the distribution parameters.
@@ -334,7 +334,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :param log_std:
         :return:
         """
-        actions = self.actions_from_params(mean_actions, log_std)
+        actions = self.actions_from_params(
+            mean_actions, log_std, latent_sde=latent_sde)
         log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob
 
diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py
index 530702e..e1bd6eb 100644
--- a/metastable_baselines/sac/policies.py
+++ b/metastable_baselines/sac/policies.py
@@ -167,7 +167,7 @@ class Actor(BasePolicy):
         """
         msg = "get_std() is only available when using gSDE"
         assert isinstance(self.action_dist,
-                          StateDependentNoiseDistribution), msg
+                          StateDependentNoiseDistribution) or (isinstance(self.action_dist, UniversalGaussianDistribution) and self.action_dist.use_sde), msg
         return self.chol
 
     def reset_noise(self, n_envs: int = 1) -> None:
@@ -199,12 +199,13 @@ class Actor(BasePolicy):
         latent_pi = self.latent_pi(features)
         mean_actions = self.mu_net(latent_pi)
 
-        if self.use_sde:
-            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         # Unstructured exploration (Original implementation)
         chol = self.chol_net(latent_pi)
+        self.chol = chol
         # Original Implementation to cap the standard deviation
-        self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        # self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        if self.use_sde:
+            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         return mean_actions, self.chol, {}
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py
index 6206dd3..e8aa378 100644
--- a/metastable_baselines/sac/sac.py
+++ b/metastable_baselines/sac/sac.py
@@ -262,9 +262,8 @@ class SAC(OffPolicyAlgorithm):
             latent_pi = act.latent_pi(features)
             mean_actions = act.mu_net(latent_pi)
 
-            # TODO: Allow contextual covariance with sde
             if self.use_sde:
-                chol = act.chol
+                chol = act.chol_net(latent_pi)
             else:
                 # Unstructured exploration (Original implementation)
                 chol = act.chol_net(latent_pi)
@@ -275,8 +274,8 @@ class SAC(OffPolicyAlgorithm):
             act_dist = self.actor.action_dist  # internal A
 
             if self.use_sde:
-                actions_pi = self.actions_from_params(
-                    mean_actions, chol, latent_pi)  # latent_pi = latent_sde
+                actions_pi = act_dist.actions_from_params(
+                    mean_actions, chol, latent_sde=latent_pi)  # latent_pi = latent_sde
             else:
                 actions_pi = act_dist.actions_from_params(
                     mean_actions, chol)
diff --git a/test.py b/test.py
index a3f2a95..d1ab583 100755
--- a/test.py
+++ b/test.py
@@ -22,26 +22,26 @@ root_path = '.'
 
 def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
-    use_sde = False
+    use_sde = True
     # th.autograd.set_detect_anomaly(True)
-    ppo = PPO(
-        MlpPolicyPPO,
+    sac = SAC(
+        MlpPolicySAC,
         env,
         # KLProjectionLayer(trust_region_coeff=0.01),
-        projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
+        #projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
         policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG,
                                        'parameterization_type': ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" +
-        env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
+        env_name+"/sac"+(['', '_sde'][use_sde])+"/",
         learning_rate=3e-4,  # 3e-4,
         gamma=0.99,
-        gae_lambda=0.95,
-        normalize_advantage=True,
-        ent_coef=0.1,  # 0.1
-        vf_coef=0.5,
+        #gae_lambda=0.95,
+        #normalize_advantage=True,
+        #ent_coef=0.1,  # 0.1
+        #vf_coef=0.5,
         use_sde=use_sde,  # False
-        clip_range=None  # 1 # 0.2,
+        #clip_range=None  # 1 # 0.2,
     )
     # trl_frob = PPO(
     #     MlpPolicy,
@@ -60,8 +60,8 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
     #     clip_range=2,  # 0.2
     # )
 
-    print('PPO:')
-    testModel(ppo, timesteps, showRes,
+    print('SAC:')
+    testModel(sac, timesteps, showRes,
               saveModel, n_eval_episodes)
     # print('TRL_frob:')
     # testModel(trl_frob, timesteps, showRes,
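
Taken together, these hunks make the gSDE path consistent: the policy's latent
features now reach the distribution under the `latent_sde` keyword everywhere,
and during training the state-dependent cholesky factor is recomputed from
`chol_net(latent_pi)` instead of reusing the cached `act.chol`, which belongs
to whatever batch last passed through the policy rather than to the replay
observations being trained on. The sketch below illustrates the resulting call
path; it is a minimal illustration that only assumes the Actor attributes shown
in the hunks above (`latent_pi`, `mu_net`, `chol_net`, `action_dist`), and the
helper itself is hypothetical, not code from this patch.

    import torch as th

    def sample_with_log_prob(actor, obs: th.Tensor):
        # Hypothetical helper mirroring Actor.get_action_dist_params
        # after this patch (not part of the repo).
        features = actor.extract_features(obs)
        latent_pi = actor.latent_pi(features)
        mean_actions = actor.mu_net(latent_pi)
        # State-dependent cholesky factor, recomputed per observation
        # so it stays differentiable w.r.t. the current parameters.
        chol = actor.chol_net(latent_pi)
        # gSDE: the latent features must be passed as `latent_sde`,
        # the keyword this patch threads through actions_from_params()
        # and log_prob_from_params().
        actions, log_prob = actor.action_dist.log_prob_from_params(
            mean_actions, chol, latent_sde=latent_pi)
        return actions, log_prob

Recomputing `chol` inside SAC.train is presumably the core of the fix: the
actor loss needs gradients through the covariance for the current batch, which
a stale cached tensor cannot provide.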