Fixed SAC+SDE+SDC bugs

Dominik Moritz Roth 2022-09-03 13:08:31 +02:00
parent 4532135812
commit ee4a0eed56
4 changed files with 25 additions and 24 deletions


@@ -320,12 +320,12 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.gaussian_actions = mode
         return self.prob_squashing_type.apply(mode)
 
-    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_pi=None) -> th.Tensor:
+    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False, latent_sde=None) -> th.Tensor:
         # Update the proba distribution
-        self.proba_distribution(mean_actions, log_std, latent_pi=latent_pi)
+        self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde=None) -> Tuple[th.Tensor, th.Tensor]:
         """
         Compute the log probability of taking an action
         given the distribution parameters.
@@ -334,7 +334,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :param log_std:
         :return:
         """
-        actions = self.actions_from_params(mean_actions, log_std)
+        actions = self.actions_from_params(
+            mean_actions, log_std, latent_sde=latent_sde)
         log_prob = self.log_prob(actions, self.gaussian_actions)
         return actions, log_prob
 
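Note: the fix in this hunk is to thread the optional latent_sde tensor through log_prob_from_params into actions_from_params instead of silently dropping it. Below is a standalone sketch of that pattern only; the class is a placeholder diagonal Gaussian, not the project's UniversalGaussianDistribution.

# Sketch of the keyword-threading fix above (placeholder class, assumptions marked).
from typing import Optional, Tuple
import torch as th

class DiagGaussianSketch:
    def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor,
                           latent_sde: Optional[th.Tensor] = None) -> "DiagGaussianSketch":
        # A real gSDE distribution would build state-dependent noise from latent_sde;
        # this placeholder keeps a plain diagonal Gaussian so the sketch stays runnable.
        self.dist = th.distributions.Normal(mean_actions, log_std.exp())
        return self

    def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor,
                            deterministic: bool = False,
                            latent_sde: Optional[th.Tensor] = None) -> th.Tensor:
        self.proba_distribution(mean_actions, log_std, latent_sde=latent_sde)
        return self.dist.mean if deterministic else self.dist.rsample()

    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor,
                             latent_sde: Optional[th.Tensor] = None
                             ) -> Tuple[th.Tensor, th.Tensor]:
        # The bug pattern fixed above: without this parameter, latent_sde was lost here.
        actions = self.actions_from_params(mean_actions, log_std, latent_sde=latent_sde)
        return actions, self.dist.log_prob(actions).sum(dim=-1)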


@@ -167,7 +167,7 @@ class Actor(BasePolicy):
         """
         msg = "get_std() is only available when using gSDE"
         assert isinstance(self.action_dist,
-                          StateDependentNoiseDistribution), msg
+                          StateDependentNoiseDistribution) or (isinstance(self.action_dist, UniversalGaussianDistribution) and self.action_dist.use_sde), msg
         return self.chol
 
     def reset_noise(self, n_envs: int = 1) -> None:
@@ -199,12 +199,13 @@ class Actor(BasePolicy):
         latent_pi = self.latent_pi(features)
         mean_actions = self.mu_net(latent_pi)
 
-        if self.use_sde:
-            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         # Unstructured exploration (Original implementation)
         chol = self.chol_net(latent_pi)
+        self.chol = chol
         # Original Implementation to cap the standard deviation
-        self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        # self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        if self.use_sde:
+            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         return mean_actions, self.chol, {}
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
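For reference, a simplified standalone sketch of the control flow this hunk establishes in get_action_dist_params: the covariance output is computed from chol_net and cached on the actor in every case, and only then does the gSDE branch return, exposing latent_pi as latent_sde. The module sizes below are placeholders, not the project's Actor.

import torch as th
import torch.nn as nn

class ActorSketch(nn.Module):
    def __init__(self, feat_dim: int = 8, act_dim: int = 2, use_sde: bool = True):
        super().__init__()
        self.latent_pi = nn.Linear(feat_dim, 16)
        self.mu_net = nn.Linear(16, act_dim)
        self.chol_net = nn.Linear(16, act_dim)
        self.use_sde = use_sde
        self.chol = th.zeros(act_dim)

    def get_action_dist_params(self, features: th.Tensor):
        latent_pi = th.tanh(self.latent_pi(features))
        mean_actions = self.mu_net(latent_pi)
        # Covariance parameters are produced and cached in every case ...
        chol = self.chol_net(latent_pi)
        self.chol = chol
        # ... so self.chol (and get_std()) never holds a stale value when gSDE returns early.
        if self.use_sde:
            return mean_actions, self.chol, dict(latent_sde=latent_pi)
        return mean_actions, self.chol, {}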


@@ -262,9 +262,8 @@ class SAC(OffPolicyAlgorithm):
         latent_pi = act.latent_pi(features)
         mean_actions = act.mu_net(latent_pi)
 
-        # TODO: Allow contextual covariance with sde
         if self.use_sde:
-            chol = act.chol
+            chol = act.chol_net(latent_pi)
         else:
             # Unstructured exploration (Original implementation)
             chol = act.chol_net(latent_pi)
@@ -275,8 +274,8 @@ class SAC(OffPolicyAlgorithm):
         act_dist = self.actor.action_dist
         # internal A
         if self.use_sde:
-            actions_pi = self.actions_from_params(
-                mean_actions, chol, latent_pi)  # latent_pi = latent_sde
+            actions_pi = act_dist.actions_from_params(
+                mean_actions, chol, latent_sde=latent_pi)  # latent_pi = latent_sde
         else:
             actions_pi = act_dist.actions_from_params(
                 mean_actions, chol)
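A hedged illustration of the call-site change in the gSDE branch: actions_from_params lives on the actor's action_dist object, so it must be called on act_dist, with latent_pi passed explicitly as the latent_sde keyword. The distribution below is a stand-in, not stable-baselines3 or project code; only the calling convention is the point.

from typing import Optional
import torch as th

class DistSketch:
    def actions_from_params(self, mean_actions: th.Tensor, chol: th.Tensor,
                            deterministic: bool = False,
                            latent_sde: Optional[th.Tensor] = None) -> th.Tensor:
        # Placeholder sampling; a gSDE distribution would derive its noise from latent_sde.
        noise = th.randn_like(mean_actions)
        return mean_actions + chol * noise

def sample_actions_pi(act_dist: DistSketch, mean_actions: th.Tensor, chol: th.Tensor,
                      latent_pi: th.Tensor, use_sde: bool) -> th.Tensor:
    if use_sde:
        # Fixed call: on the distribution object, with latent_pi passed as latent_sde.
        return act_dist.actions_from_params(mean_actions, chol, latent_sde=latent_pi)
    return act_dist.actions_from_params(mean_actions, chol)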

test.py

@@ -22,26 +22,26 @@ root_path = '.'
 def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
-    use_sde = False
+    use_sde = True
     # th.autograd.set_detect_anomaly(True)
-    ppo = PPO(
-        MlpPolicyPPO,
+    sac = SAC(
+        MlpPolicySAC,
         env,
         # KLProjectionLayer(trust_region_coeff=0.01),
-        projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
+        #projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
         policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
                        ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" +
-        env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
+        env_name+"/sac"+(['', '_sde'][use_sde])+"/",
         learning_rate=3e-4,  # 3e-4,
         gamma=0.99,
-        gae_lambda=0.95,
-        normalize_advantage=True,
-        ent_coef=0.1,  # 0.1
-        vf_coef=0.5,
+        #gae_lambda=0.95,
+        #normalize_advantage=True,
+        #ent_coef=0.1,  # 0.1
+        #vf_coef=0.5,
         use_sde=use_sde,  # False
-        clip_range=None  # 1 # 0.2,
+        #clip_range=None  # 1 # 0.2,
     )
 
     # trl_frob = PPO(
     #     MlpPolicy,
@@ -60,8 +60,8 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
     #     clip_range=2,  # 0.2
     # )
 
-    print('PPO:')
-    testModel(ppo, timesteps, showRes,
+    print('SAC:')
+    testModel(sac, timesteps, showRes,
               saveModel, n_eval_episodes)
     # print('TRL_frob:')
     # testModel(trl_frob, timesteps, showRes,