SAC is back; with SDC; without Projections

Dominik Moritz Roth 2022-07-19 10:07:50 +02:00
parent 5f32435751
commit 7b667e9650


@@ -11,7 +11,8 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
 from ..misc.distTools import new_dist_like
@@ -28,7 +29,7 @@ LOG_STD_MAX = 2
 LOG_STD_MIN = -20

-class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
+class SAC(OffPolicyAlgorithm):
     """
     Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
     This implementation is almost a 1:1-copy of the sb3-code for SAC.
@@ -127,14 +128,15 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
         # Different from SAC:
-        projection: BaseProjectionLayer = KLProjectionLayer(),
+        # projection: BaseProjectionLayer = BaseProjectionLayer(),
+        projection=None,
         _init_setup_model: bool = True,
     ):
         super().__init__(
             policy,
             env,
+            None,  # PolicyBase
             learning_rate,
             buffer_size,
             learning_starts,
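(Note on the hunk above: the new bare `None` argument appears to fill the `policy_base` positional parameter that stable-baselines3's `OffPolicyAlgorithm.__init__` still expected between `env` and `learning_rate` at the time of this commit; the inline `# PolicyBase` comment supports that reading.)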
@@ -160,8 +162,6 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             support_multi_env=True,
         )

-        raise Exception('TRL_SAC is not yet implemented')
-
         self.target_entropy = target_entropy
         self.log_ent_coef = None  # type: Optional[th.Tensor]
         # Entropy coefficient / Entropy temperature
@@ -171,8 +171,13 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         self.ent_coef_optimizer = None
         # Different from SAC:
-        self.projection = projection
+        # self.projection = projection
         self._global_steps = 0
+        self.n_steps = buffer_size
+        self.gae_lambda = False
+        if projection is not None:
+            print('[!] A projection was supplied! It will be ignored!')
+
         if _init_setup_model:
             self._setup_model()
@@ -255,30 +260,32 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             act = self.actor
             features = act.extract_features(replay_data.observations)
             latent_pi = act.latent_pi(features)
-            mean_actions = act.mu(latent_pi)
+            mean_actions = act.mu_net(latent_pi)

             # TODO: Allow contextual covariance with sde
             if self.use_sde:
-                log_std = act.log_std
+                chol = act.chol
             else:
                 # Unstructured exploration (Original implementation)
-                log_std = act.log_std(latent_pi)
+                chol = act.chol_net(latent_pi)
                 # Original Implementation to cap the standard deviation
-                log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
+                chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+            act.chol = chol

-            act_dist = self.action_dist
+            act_dist = self.actor.action_dist

             # internal A
             if self.use_sde:
                 actions_pi = self.actions_from_params(
-                    mean_actions, log_std, latent_pi)  # latent_pi = latent_sde
+                    mean_actions, chol, latent_pi)  # latent_pi = latent_sde
             else:
                 actions_pi = act_dist.actions_from_params(
-                    mean_actions, log_std)
+                    mean_actions, chol)

-            p_dist = self.action_dist.distribution
-            q_dist = new_dist_like(
-                p_dist, replay_data.means, replay_data.stds)
-            proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            p_dist = act_dist.distribution
+            # q_dist = new_dist_like(
+            #     p_dist, replay_data.means, replay_data.stds)
+            # proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            proj_p = p_dist

             log_prob = proj_p.log_prob(actions_pi)
             log_prob = log_prob.reshape(-1, 1)
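The hunk above swaps the `(mean, log_std)` parameterization for `(mean, chol)` and short-circuits the projection (`proj_p = p_dist`), so log-probabilities come straight from the unprojected policy distribution. A minimal, self-contained sketch of the diagonal case in plain PyTorch (illustrative names only; it assumes the clamped `chol` values are still stored in log-space, as the reuse of `LOG_STD_MIN`/`LOG_STD_MAX` suggests):

import torch as th
from torch.distributions import Normal

mean = th.zeros(4, 2)                     # batch of action means
chol = th.clamp(th.randn(4, 2), -20, 2)   # diagonal factor, clamped like log_std

p_dist = Normal(mean, chol.exp())         # diagonal Gaussian from mean and scale
actions = p_dist.rsample()                # reparameterized sample (keeps gradients)
log_prob = p_dist.log_prob(actions).sum(-1, keepdim=True)  # one value per sample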
@@ -368,6 +375,23 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

+        pol = self.policy.actor
+        if hasattr(pol, "log_std"):
+            self.logger.record(
+                "train/std", th.exp(pol.log_std).mean().item())
+        elif hasattr(pol, "chol"):
+            chol = pol.chol
+            if len(chol.shape) == 1:
+                self.logger.record(
+                    "train/std", th.mean(chol).mean().item())
+            elif len(chol.shape) == 2:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
+            else:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item())
+
     def learn(
         self,
         total_timesteps: int,
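The new `train/std` logging condenses whatever covariance parameterization the actor uses into a single scalar. For a full Cholesky factor L, the per-dimension standard deviations are the square roots of the covariance diagonal; a small self-check under the common Σ = L·Lᵀ convention (an assumption here — the hunk itself computes `chol.T @ chol`, i.e. Σ = Lᵀ·L):

import torch as th

L = th.tril(th.rand(3, 3)) + th.eye(3)  # lower-triangular Cholesky factor
cov = L @ L.T                           # covariance under the sigma = L @ L.T convention
std = th.sqrt(th.diagonal(cov))         # per-dimension standard deviations

# Same result without materializing the covariance: the row norms of L.
assert th.allclose(std, L.pow(2).sum(dim=1).sqrt())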