diff --git a/metastable_baselines/sac/sac.py b/metastable_baselines/sac/sac.py index bd3ece5..202d432 100644 --- a/metastable_baselines/sac/sac.py +++ b/metastable_baselines/sac/sac.py @@ -122,6 +122,7 @@ class SAC(OffPolicyAlgorithm): use_sde_at_warmup: bool = False, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, + action_coef: float = 0.0, policy_kwargs: Optional[Dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, @@ -176,6 +177,8 @@ class SAC(OffPolicyAlgorithm): self.n_steps = buffer_size self.gae_lambda = False + self.action_coef = action_coef + if projection != None: print('[!] An projection was supplied! Will be ignored!') @@ -236,6 +239,7 @@ class SAC(OffPolicyAlgorithm): ent_coef_losses, ent_coefs = [], [] actor_losses, critic_losses = [], [] + action_losses = [] for gradient_step in range(gradient_steps): # Sample replay buffer @@ -330,6 +334,11 @@ class SAC(OffPolicyAlgorithm): projection_loss = th.zeros(1) + # 'Principle of least action' + action_loss = th.mean(th.square(actions_pi)) + + action_losses.append(action_loss.item()) + # Compute critic loss critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values) @@ -347,7 +356,8 @@ class SAC(OffPolicyAlgorithm): q_values_pi = th.cat(self.critic( replay_data.observations, actions_pi), dim=1) min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True) - actor_loss = (ent_coef * log_prob - min_qf_pi).mean() + actor_loss = (ent_coef * log_prob - min_qf_pi + + self.action_coef * action_loss).mean() actor_losses.append(actor_loss.item()) # Optimize the actor @@ -364,6 +374,7 @@ class SAC(OffPolicyAlgorithm): self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/action_loss", np.mean(action_losses)) self.logger.record("train/ent_coef", np.mean(ent_coefs)) self.logger.record("train/actor_loss", np.mean(actor_losses)) self.logger.record("train/critic_loss", np.mean(critic_losses))