From 6e1a7cecd5a5845000f8f45cb2a35dba0b5a340f Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 3 Sep 2022 13:46:17 +0200 Subject: [PATCH] Implemented correct clipping (from original SAC) --- metastable_baselines/sac/policies.py | 6 +++--- metastable_baselines/sac/sac.py | 12 ++++-------- test.py | 1 + 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/metastable_baselines/sac/policies.py b/metastable_baselines/sac/policies.py index e1bd6eb..3a6b297 100644 --- a/metastable_baselines/sac/policies.py +++ b/metastable_baselines/sac/policies.py @@ -23,8 +23,8 @@ from stable_baselines3.common.type_aliases import Schedule from ..distributions import UniversalGaussianDistribution # CAP the standard deviation of the actor -LOG_STD_MAX = 2 -LOG_STD_MIN = -20 +CHOL_MIN = 0.001 +CHOL_MAX = 1000 class Actor(BasePolicy): @@ -203,7 +203,7 @@ class Actor(BasePolicy): chol = self.chol_net(latent_pi) self.chol = chol # Original Implementation to cap the standard deviation - # self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX) + self.chol = th.clamp(chol, CHOL_MIN, CHOL_MAX) if self.use_sde: return mean_actions, self.chol, dict(latent_sde=latent_pi) return mean_actions, self.chol, {} diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py index e8aa378..bd3ece5 100644 --- a/metastable_baselines/sac/sac.py +++ b/metastable_baselines/sac/sac.py @@ -262,14 +262,10 @@ class SAC(OffPolicyAlgorithm): latent_pi = act.latent_pi(features) mean_actions = act.mu_net(latent_pi) - if self.use_sde: - chol = act.chol_net(latent_pi) - else: - # Unstructured exploration (Original implementation) - chol = act.chol_net(latent_pi) - # Original Implementation to cap the standard deviation - chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX) - act.chol = chol + chol = act.chol_net(latent_pi) + # Original Implementation to cap the standard deviation + chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX) + act.chol = chol act_dist = self.actor.action_dist # internal A diff --git a/test.py b/test.py index d1ab583..b074e54 100755 --- a/test.py +++ b/test.py @@ -41,6 +41,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru #ent_coef=0.1, # 0.1 #vf_coef=0.5, use_sde=use_sde, # False + sde_sample_freq=8, #clip_range=None # 1 # 0.2, ) # trl_frob = PPO(