Implemented correct clipping (from original SAC)

This commit is contained in:
Dominik Moritz Roth 2022-09-03 13:46:17 +02:00
parent ee4a0eed56
commit 6e1a7cecd5
3 changed files with 8 additions and 11 deletions

View File

@ -23,8 +23,8 @@ from stable_baselines3.common.type_aliases import Schedule
from ..distributions import UniversalGaussianDistribution from ..distributions import UniversalGaussianDistribution
# CAP the standard deviation of the actor # CAP the standard deviation of the actor
LOG_STD_MAX = 2 CHOL_MIN = 0.001
LOG_STD_MIN = -20 CHOL_MAX = 1000
class Actor(BasePolicy): class Actor(BasePolicy):
@ -203,7 +203,7 @@ class Actor(BasePolicy):
chol = self.chol_net(latent_pi) chol = self.chol_net(latent_pi)
self.chol = chol self.chol = chol
# Original Implementation to cap the standard deviation # Original Implementation to cap the standard deviation
# self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX) self.chol = th.clamp(chol, CHOL_MIN, CHOL_MAX)
if self.use_sde: if self.use_sde:
return mean_actions, self.chol, dict(latent_sde=latent_pi) return mean_actions, self.chol, dict(latent_sde=latent_pi)
return mean_actions, self.chol, {} return mean_actions, self.chol, {}

View File

@ -262,14 +262,10 @@ class SAC(OffPolicyAlgorithm):
latent_pi = act.latent_pi(features) latent_pi = act.latent_pi(features)
mean_actions = act.mu_net(latent_pi) mean_actions = act.mu_net(latent_pi)
if self.use_sde: chol = act.chol_net(latent_pi)
chol = act.chol_net(latent_pi) # Original Implementation to cap the standard deviation
else: chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
# Unstructured exploration (Original implementation) act.chol = chol
chol = act.chol_net(latent_pi)
# Original Implementation to cap the standard deviation
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
act.chol = chol
act_dist = self.actor.action_dist act_dist = self.actor.action_dist
# internal A # internal A

View File

@ -41,6 +41,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
#ent_coef=0.1, # 0.1 #ent_coef=0.1, # 0.1
#vf_coef=0.5, #vf_coef=0.5,
use_sde=use_sde, # False use_sde=use_sde, # False
sde_sample_freq=8,
#clip_range=None # 1 # 0.2, #clip_range=None # 1 # 0.2,
) )
# trl_frob = PPO( # trl_frob = PPO(