SAC is back; with SDC; without Projections
parent 5f32435751
commit 7b667e9650
@@ -11,7 +11,8 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+from metastable_baselines.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
+from ..misc.distTools import new_dist_like
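The newly imported new_dist_like is not shown in this diff. Judging by the call site further down, new_dist_like(p_dist, replay_data.means, replay_data.stds), a minimal sketch of such a helper might look like this (the supported distribution families are an assumption, not taken from this repo):

import torch as th

def new_dist_like(orig_dist, mean, std_or_chol):
    # Sketch: rebuild a distribution of the same family as orig_dist,
    # parameterized by the means/stds stored in the replay buffer.
    if isinstance(orig_dist, th.distributions.Normal):
        return th.distributions.Normal(mean, std_or_chol)
    if isinstance(orig_dist, th.distributions.MultivariateNormal):
        return th.distributions.MultivariateNormal(mean, scale_tril=std_or_chol)
    raise ValueError(f"Unsupported distribution type: {type(orig_dist)}")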
@@ -28,7 +29,7 @@ LOG_STD_MAX = 2
 LOG_STD_MIN = -20
 
 
-class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
+class SAC(OffPolicyAlgorithm):
     """
     Trust Region Layers (TRL) based on SAC (Soft Actor Critic)
     This implementation is almost a 1:1-copy of the sb3-code for SAC.
@@ -127,14 +128,15 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         device: Union[th.device, str] = "auto",
 
         # Different from SAC:
-        projection: BaseProjectionLayer = KLProjectionLayer(),
-        # projection: BaseProjectionLayer = BaseProjectionLayer(),
+        projection=None,
 
         _init_setup_model: bool = True,
     ):
 
         super().__init__(
             policy,
             env,
             None,  # PolicyBase
             learning_rate,
             buffer_size,
             learning_starts,
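For reference, the dropped default came from a projection-layer class. The interface the code expected can be inferred from the call site that is commented out later in this diff; a rough sketch (the name ProjectionLayerLike and the exact signature are assumptions, not the actual trust-region-layers API):

from typing import Protocol

class ProjectionLayerLike(Protocol):
    # Inferred from the commented-out call further down:
    #   proj_p = self.projection(p_dist, q_dist, self._global_steps)
    def __call__(self, p_dist, q_dist, global_steps: int): ...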
@@ -160,8 +162,6 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             support_multi_env=True,
         )
 
-        raise Exception('TRL_SAC is not yet implemented')
-
         self.target_entropy = target_entropy
         self.log_ent_coef = None  # type: Optional[th.Tensor]
         # Entropy coefficient / Entropy temperature
@@ -171,8 +171,13 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         self.ent_coef_optimizer = None
 
         # Different from SAC:
-        self.projection = projection
+        # self.projection = projection
+        self._global_steps = 0
+        self.n_steps = buffer_size
+        self.gae_lambda = False
+
+        if projection is not None:
+            print('[!] A projection was supplied! It will be ignored!')
 
         if _init_setup_model:
             self._setup_model()
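With the projection argument defaulting to None and otherwise ignored, construction reduces to plain SAC usage. A hypothetical sketch (import path, env id, and hyperparameters are illustrative only, not taken from this repo):

from metastable_baselines.sac import SAC  # import path assumed

# Standard sb3-style construction; projection is left at its None default.
model = SAC("MlpPolicy", "Pendulum-v1", buffer_size=100_000)
model.learn(total_timesteps=10_000)

# Supplying a projection now only triggers the warning above and is ignored.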
@@ -255,30 +260,32 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
             act = self.actor
             features = act.extract_features(replay_data.observations)
             latent_pi = act.latent_pi(features)
-            mean_actions = act.mu(latent_pi)
+            mean_actions = act.mu_net(latent_pi)
 
             # TODO: Allow contextual covariance with sde
             if self.use_sde:
-                log_std = act.log_std
+                chol = act.chol
             else:
                 # Unstructured exploration (Original implementation)
-                log_std = act.log_std(latent_pi)
+                chol = act.chol_net(latent_pi)
                 # Original Implementation to cap the standard deviation
-                log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
+                chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+                act.chol = chol
 
-            act_dist = self.action_dist
+            act_dist = self.actor.action_dist
             # internal A
             if self.use_sde:
                 actions_pi = self.actions_from_params(
-                    mean_actions, log_std, latent_pi)  # latent_pi = latent_sde
+                    mean_actions, chol, latent_pi)  # latent_pi = latent_sde
             else:
                 actions_pi = act_dist.actions_from_params(
-                    mean_actions, log_std)
+                    mean_actions, chol)
 
-            p_dist = self.action_dist.distribution
-            q_dist = new_dist_like(
-                p_dist, replay_data.means, replay_data.stds)
-            proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            p_dist = act_dist.distribution
+            # q_dist = new_dist_like(
+            #     p_dist, replay_data.means, replay_data.stds)
+            # proj_p = self.projection(p_dist, q_dist, self._global_steps)
+            proj_p = p_dist
             log_prob = proj_p.log_prob(actions_pi)
             log_prob = log_prob.reshape(-1, 1)
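actions_from_params comes from sb3's distribution API: conceptually, it builds the distribution from the given parameters and draws a reparameterized sample so gradients flow through the draw into mean and scale. A minimal stand-alone sketch for the diagonal case (a simplification that assumes chol holds positive per-dimension scales; this is not the sb3 implementation):

import torch as th

def actions_from_params(mean_actions: th.Tensor, chol: th.Tensor) -> th.Tensor:
    # Diagonal case: chol acts as a per-dimension scale (must be positive).
    dist = th.distributions.Normal(mean_actions, chol)
    # rsample() uses the reparameterization trick, keeping the sample
    # differentiable w.r.t. mean_actions and chol.
    return dist.rsample()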
@@ -368,6 +375,23 @@ class TRL_SAC(GaussianRolloutCollectorAuxclass, OffPolicyAlgorithm):
         if len(ent_coef_losses) > 0:
             self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
 
+        pol = self.policy.actor
+
+        if hasattr(pol, "log_std"):
+            self.logger.record(
+                "train/std", th.exp(pol.log_std).mean().item())
+        elif hasattr(pol, "chol"):
+            chol = pol.chol
+            if len(chol.shape) == 1:
+                self.logger.record(
+                    "train/std", th.mean(chol).mean().item())
+            elif len(chol.shape) == 2:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.T @ chol, dim1=-2, dim2=-1))).mean().item())
+            else:
+                self.logger.record(
+                    "train/std", th.mean(th.sqrt(th.diagonal(chol.mT @ chol, dim1=-2, dim2=-1))).mean().item())
+
     def learn(
         self,
         total_timesteps: int,
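The new logging block reduces diagonal(chol^T @ chol) in the matrix cases, which matches an upper-triangular factorization Sigma = U^T U. Under the lower-triangular convention Sigma = L L^T, the marginal standard deviations come from the row norms of L instead. A small sketch covering both conventions (this repo's convention is not stated in the diff):

import torch as th

def marginal_stds(chol: th.Tensor, lower: bool = True) -> th.Tensor:
    # Marginal standard deviations from a (batched) triangular factor.
    # lower=True : Sigma = L @ L^T -> sqrt of the diagonal = row norms of L
    # lower=False: Sigma = U^T @ U -> sqrt of the diagonal = column norms of U
    cov = chol @ chol.mT if lower else chol.mT @ chol
    return th.sqrt(th.diagonal(cov, dim1=-2, dim2=-1))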