Implemented correct clipping (from original SAC)

This commit is contained in:
Dominik Moritz Roth 2022-09-03 13:46:17 +02:00
parent ee4a0eed56
commit 6e1a7cecd5
3 changed files with 8 additions and 11 deletions

View File

@@ -23,8 +23,8 @@ from stable_baselines3.common.type_aliases import Schedule
from ..distributions import UniversalGaussianDistribution
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
LOG_STD_MIN = -20
CHOL_MIN = 0.001
CHOL_MAX = 1000
class Actor(BasePolicy):
@@ -203,7 +203,7 @@ class Actor(BasePolicy):
chol = self.chol_net(latent_pi)
self.chol = chol
# Original Implementation to cap the standard deviation
# self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
self.chol = th.clamp(chol, CHOL_MIN, CHOL_MAX)
if self.use_sde:
return mean_actions, self.chol, dict(latent_sde=latent_pi)
return mean_actions, self.chol, {}

View File

@@ -262,10 +262,6 @@ class SAC(OffPolicyAlgorithm):
latent_pi = act.latent_pi(features)
mean_actions = act.mu_net(latent_pi)
if self.use_sde:
chol = act.chol_net(latent_pi)
else:
# Unstructured exploration (Original implementation)
chol = act.chol_net(latent_pi)
# Original Implementation to cap the standard deviation
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)

View File

@@ -41,6 +41,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
#ent_coef=0.1, # 0.1
#vf_coef=0.5,
use_sde=use_sde, # False
sde_sample_freq=8,
#clip_range=None # 1 # 0.2,
)
# trl_frob = PPO(