SAC is back; with SDC; without Projections

Dominik Moritz Roth 2022-07-19 10:07:50 +02:00
parent 5f32435751
commit 7b667e9650


@@ -11,7 +11,8 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
from ..misc.distTools import new_dist_like
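The unchanged `polyak_update` import is SB3's soft target-network update used for SAC's target critics. As a minimal sketch of the equivalent in-place update (the name `soft_update` is illustrative; SB3's `polyak_update(params, target_params, tau)` should have the same effect, with `tau` being the fraction taken from the online weights):

import torch as th

def soft_update(params, target_params, tau):
    """In-place Polyak averaging: target <- tau * online + (1 - tau) * target."""
    with th.no_grad():
        for param, target_param in zip(params, target_params):
            target_param.mul_(1.0 - tau)
            target_param.add_(tau * param)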
@@ -28,7 +29,7 @@ LOG_STD_MAX = 2
LOG_STD_MIN = -20
class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
class SAC(OffPolicyAlgorithm):
"""
Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
This implementation is almost a 1:1 copy of the sb3 code for SAC.
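For context on the projection this commit disables: a trust-region layer projects the new policy's Gaussian parameters back toward the old ones whenever a divergence bound is exceeded. The snippet below is only a conceptual sketch of a Mahalanobis-style mean projection with a diagonal covariance; it is not the actual KLProjectionLayer, and `project_mean` and `eps` are illustrative names.

import torch as th

def project_mean(mean, old_mean, old_std, eps):
    # squared Mahalanobis distance of the new mean under the old (diagonal) covariance
    maha = (((mean - old_mean) / old_std) ** 2).sum(-1, keepdim=True)
    # shrink the mean step so the distance lands on the trust-region boundary eps
    scale = th.where(maha > eps, th.sqrt(eps / (maha + 1e-12)), th.ones_like(maha))
    return old_mean + scale * (mean - old_mean)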
@@ -127,14 +128,15 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
device: Union[th.device, str] = "auto",
# Different from SAC:
projection: BaseProjectionLayer = KLProjectionLayer(),
# projection: BaseProjectionLayer = BaseProjectionLayer(),
projection=None,
_init_setup_model: bool = True,
):
super().__init__(
policy,
env,
None, # PolicyBase
learning_rate,
buffer_size,
learning_starts,
@@ -160,8 +162,6 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
support_multi_env=True,
)
raise Exception('TRL_SAC is not yet implemented')
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
# Entropy coefficient / Entropy temperature
@@ -171,8 +171,13 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
self.ent_coef_optimizer = None
# Different from SAC:
self.projection = projection
# self.projection = projection
self._global_steps = 0
self.n_steps = buffer_size
self.gae_lambda = False
if projection is not None:
print('[!] A projection was supplied! It will be ignored!')
if _init_setup_model:
self._setup_model()
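The block above keeps SAC's automatic entropy-temperature machinery (`target_entropy`, `log_ent_coef`, `ent_coef_optimizer`). A minimal, self-contained sketch of how that tuning typically works (variable names, the action dimensionality, and the learning rate are illustrative, not the exact sb3 code):

import torch as th

action_dim = 2                                   # assumed action dimensionality
target_entropy = -float(action_dim)              # common heuristic: -|A|

log_ent_coef = th.zeros(1, requires_grad=True)   # learn log(alpha) so alpha stays positive
ent_coef_optimizer = th.optim.Adam([log_ent_coef], lr=3e-4)

def update_ent_coef(log_prob):
    # log_prob: log-probabilities of actions sampled from the current policy
    ent_coef_loss = -(log_ent_coef * (log_prob + target_entropy).detach()).mean()
    ent_coef_optimizer.zero_grad()
    ent_coef_loss.backward()
    ent_coef_optimizer.step()
    return th.exp(log_ent_coef.detach())         # alpha used to weight the entropy bonus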
@@ -255,30 +260,32 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
act = self.actor
features = act.extract_features(replay_data.observations)
latent_pi = act.latent_pi(features)
mean_actions = act.mu(latent_pi)
mean_actions = act.mu_net(latent_pi)
# TODO: Allow contextual covariance with sde
if self.use_sde:
log_std = act.log_std
chol = act.chol
else:
# Unstructured exploration (Original implementation)
log_std = act.log_std(latent_pi)
chol = act.chol_net(latent_pi)
# Original Implementation to cap the standard deviation
log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
act.chol = chol
act_dist = self.action_dist
act_dist = self.actor.action_dist
# internal A
if self.use_sde:
actions_pi = self.actions_from_params(
mean_actions, log_std, latent_pi) # latent_pi = latent_sde
mean_actions, chol, latent_pi) # latent_pi = latent_sde
else:
actions_pi = act_dist.actions_from_params(
mean_actions, log_std)
mean_actions, chol)
p_dist = self.action_dist.distribution
q_dist = new_dist_like(
p_dist, replay_data.means, replay_data.stds)
proj_p = self.projection(p_dist, q_dist, self._global_steps)
p_dist = act_dist.distribution
# q_dist = new_dist_like(
# p_dist, replay_data.means, replay_data.stds)
#proj_p = self.projection(p_dist, q_dist, self._global_steps)
proj_p = p_dist
log_prob = proj_p.log_prob(actions_pi)
log_prob = log_prob.reshape(-1, 1)
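The rewritten branch above parameterizes the policy with a Cholesky factor (`chol`) instead of a per-dimension `log_std`. A small standalone sketch of drawing reparameterized actions and their log-probabilities from such a parameterization with `torch.distributions` (the project's own `action_dist` wiring is not reproduced here; shapes and values are illustrative):

import torch as th
from torch.distributions import MultivariateNormal

batch, act_dim = 4, 2
mean_actions = th.zeros(batch, act_dim)
# one lower-triangular Cholesky factor of the covariance per sample
chol = 0.5 * th.eye(act_dim).expand(batch, act_dim, act_dim)

dist = MultivariateNormal(mean_actions, scale_tril=chol)
actions_pi = dist.rsample()                           # reparameterized sample, keeps gradients
log_prob = dist.log_prob(actions_pi).reshape(-1, 1)   # shape (batch, 1), matching the code above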
@@ -368,6 +375,23 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
if len(ent_coef_losses) > 0:
self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
pol = self.policy.actor
if hasattr(pol, "log_std"):
self.logger.record(
"train/std", th.exp(pol.log_std).mean().item())
elif hasattr(pol, "chol"):
chol = pol.chol
if len(chol.shape) == 1:
self.logger.record(
"train/std", th.mean(chol).mean().item())
elif len(chol.shape) == 2:
self.logger.record(
"train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
else:
self.logger.record(
"train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item())
def learn(
self,
total_timesteps: int,