Implemented correct clipping (from original SAC)
This commit is contained in:
parent
ee4a0eed56
commit
6e1a7cecd5
@ -23,8 +23,8 @@ from stable_baselines3.common.type_aliases import Schedule
|
||||
from ..distributions import UniversalGaussianDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
LOG_STD_MAX = 2
|
||||
LOG_STD_MIN = -20
|
||||
CHOL_MIN = 0.001
|
||||
CHOL_MAX = 1000
|
||||
|
||||
|
||||
class Actor(BasePolicy):
|
||||
@ -203,7 +203,7 @@ class Actor(BasePolicy):
|
||||
chol = self.chol_net(latent_pi)
|
||||
self.chol = chol
|
||||
# Original Implementation to cap the standard deviation
|
||||
# self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
||||
self.chol = th.clamp(chol, CHOL_MIN, CHOL_MAX)
|
||||
if self.use_sde:
|
||||
return mean_actions, self.chol, dict(latent_sde=latent_pi)
|
||||
return mean_actions, self.chol, {}
|
||||
|
@ -262,14 +262,10 @@ class SAC(OffPolicyAlgorithm):
|
||||
latent_pi = act.latent_pi(features)
|
||||
mean_actions = act.mu_net(latent_pi)
|
||||
|
||||
if self.use_sde:
|
||||
chol = act.chol_net(latent_pi)
|
||||
else:
|
||||
# Unstructured exploration (Original implementation)
|
||||
chol = act.chol_net(latent_pi)
|
||||
# Original Implementation to cap the standard deviation
|
||||
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
||||
act.chol = chol
|
||||
chol = act.chol_net(latent_pi)
|
||||
# Original Implementation to cap the standard deviation
|
||||
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
|
||||
act.chol = chol
|
||||
|
||||
act_dist = self.actor.action_dist
|
||||
# internal A
|
||||
|
Loading…
Reference in New Issue
Block a user