Implemented correct clipping (from original SAC)
This commit is contained in:
parent
ee4a0eed56
commit
6e1a7cecd5
@ -23,8 +23,8 @@ from stable_baselines3.common.type_aliases import Schedule
|
|||||||
from ..distributions import UniversalGaussianDistribution
|
from ..distributions import UniversalGaussianDistribution
|
||||||
|
|
||||||
# CAP the standard deviation of the actor
|
# CAP the standard deviation of the actor
|
||||||
LOG_STD_MAX = 2
|
CHOL_MIN = 0.001
|
||||||
LOG_STD_MIN = -20
|
CHOL_MAX = 1000
|
||||||
|
|
||||||
|
|
||||||
class Actor(BasePolicy):
|
class Actor(BasePolicy):
|
||||||
@ -203,7 +203,7 @@ class Actor(BasePolicy):
|
|||||||
chol = self.chol_net(latent_pi)
|
chol = self.chol_net(latent_pi)
|
||||||
self.chol = chol
|
self.chol = chol
|
||||||
# Original Implementation to cap the standard deviation
|
# Original Implementation to cap the standard deviation
|
||||||
# self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
self.chol = th.clamp(chol, CHOL_MIN, CHOL_MAX)
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
return mean_actions, self.chol, dict(latent_sde=latent_pi)
|
return mean_actions, self.chol, dict(latent_sde=latent_pi)
|
||||||
return mean_actions, self.chol, {}
|
return mean_actions, self.chol, {}
|
||||||
|
@ -262,10 +262,6 @@ class SAC(OffPolicyAlgorithm):
|
|||||||
latent_pi = act.latent_pi(features)
|
latent_pi = act.latent_pi(features)
|
||||||
mean_actions = act.mu_net(latent_pi)
|
mean_actions = act.mu_net(latent_pi)
|
||||||
|
|
||||||
if self.use_sde:
|
|
||||||
chol = act.chol_net(latent_pi)
|
|
||||||
else:
|
|
||||||
# Unstructured exploration (Original implementation)
|
|
||||||
chol = act.chol_net(latent_pi)
|
chol = act.chol_net(latent_pi)
|
||||||
# Original Implementation to cap the standard deviation
|
# Original Implementation to cap the standard deviation
|
||||||
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
||||||
|
1
test.py
1
test.py
@ -41,6 +41,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
|
|||||||
#ent_coef=0.1, # 0.1
|
#ent_coef=0.1, # 0.1
|
||||||
#vf_coef=0.5,
|
#vf_coef=0.5,
|
||||||
use_sde=use_sde, # False
|
use_sde=use_sde, # False
|
||||||
|
sde_sample_freq=8,
|
||||||
#clip_range=None # 1 # 0.2,
|
#clip_range=None # 1 # 0.2,
|
||||||
)
|
)
|
||||||
# trl_frob = PPO(
|
# trl_frob = PPO(
|
||||||
|
Loading…
Reference in New Issue
Block a user