Action-Loss for SAC

This commit is contained in:
Dominik Moritz Roth 2022-10-08 18:33:37 +02:00
parent 3110275d7b
commit f9a08add40

View File

@ -122,6 +122,7 @@ class SAC(OffPolicyAlgorithm):
use_sde_at_warmup: bool = False, use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None, tensorboard_log: Optional[str] = None,
create_eval_env: bool = False, create_eval_env: bool = False,
action_coef: float = 0.0,
policy_kwargs: Optional[Dict[str, Any]] = None, policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0, verbose: int = 0,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -176,6 +177,8 @@ class SAC(OffPolicyAlgorithm):
self.n_steps = buffer_size self.n_steps = buffer_size
self.gae_lambda = False self.gae_lambda = False
self.action_coef = action_coef
if projection != None: if projection != None:
print('[!] An projection was supplied! Will be ignored!') print('[!] An projection was supplied! Will be ignored!')
@ -236,6 +239,7 @@ class SAC(OffPolicyAlgorithm):
ent_coef_losses, ent_coefs = [], [] ent_coef_losses, ent_coefs = [], []
actor_losses, critic_losses = [], [] actor_losses, critic_losses = [], []
action_losses = []
for gradient_step in range(gradient_steps): for gradient_step in range(gradient_steps):
# Sample replay buffer # Sample replay buffer
@ -330,6 +334,11 @@ class SAC(OffPolicyAlgorithm):
projection_loss = th.zeros(1) projection_loss = th.zeros(1)
# 'Principle of least action'
action_loss = th.mean(th.square(actions_pi))
action_losses.append(action_loss.item())
# Compute critic loss # Compute critic loss
critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values) critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values)
for current_q in current_q_values) for current_q in current_q_values)
@ -347,7 +356,8 @@ class SAC(OffPolicyAlgorithm):
q_values_pi = th.cat(self.critic( q_values_pi = th.cat(self.critic(
replay_data.observations, actions_pi), dim=1) replay_data.observations, actions_pi), dim=1)
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - min_qf_pi).mean() actor_loss = (ent_coef * log_prob - min_qf_pi +
self.action_coef * action_loss).mean()
actor_losses.append(actor_loss.item()) actor_losses.append(actor_loss.item())
# Optimize the actor # Optimize the actor
@ -364,6 +374,7 @@ class SAC(OffPolicyAlgorithm):
self.logger.record("train/n_updates", self.logger.record("train/n_updates",
self._n_updates, exclude="tensorboard") self._n_updates, exclude="tensorboard")
self.logger.record("train/action_loss", np.mean(action_losses))
self.logger.record("train/ent_coef", np.mean(ent_coefs)) self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses)) self.logger.record("train/actor_loss", np.mean(actor_losses))
self.logger.record("train/critic_loss", np.mean(critic_losses)) self.logger.record("train/critic_loss", np.mean(critic_losses))