Action-Loss for SAC
This commit is contained in:
parent
3110275d7b
commit
f9a08add40
@ -122,6 +122,7 @@ class SAC(OffPolicyAlgorithm):
|
||||
use_sde_at_warmup: bool = False,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
create_eval_env: bool = False,
|
||||
action_coef: float = 0.0,
|
||||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
@ -176,6 +177,8 @@ class SAC(OffPolicyAlgorithm):
|
||||
self.n_steps = buffer_size
|
||||
self.gae_lambda = False
|
||||
|
||||
self.action_coef = action_coef
|
||||
|
||||
if projection != None:
|
||||
print('[!] An projection was supplied! Will be ignored!')
|
||||
|
||||
@ -236,6 +239,7 @@ class SAC(OffPolicyAlgorithm):
|
||||
|
||||
ent_coef_losses, ent_coefs = [], []
|
||||
actor_losses, critic_losses = [], []
|
||||
action_losses = []
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
@ -330,6 +334,11 @@ class SAC(OffPolicyAlgorithm):
|
||||
|
||||
projection_loss = th.zeros(1)
|
||||
|
||||
# 'Principle of least action'
|
||||
action_loss = th.mean(th.square(actions_pi))
|
||||
|
||||
action_losses.append(action_loss.item())
|
||||
|
||||
# Compute critic loss
|
||||
critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values)
|
||||
for current_q in current_q_values)
|
||||
@ -347,7 +356,8 @@ class SAC(OffPolicyAlgorithm):
|
||||
q_values_pi = th.cat(self.critic(
|
||||
replay_data.observations, actions_pi), dim=1)
|
||||
min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi +
|
||||
self.action_coef * action_loss).mean()
|
||||
actor_losses.append(actor_loss.item())
|
||||
|
||||
# Optimize the actor
|
||||
@ -364,6 +374,7 @@ class SAC(OffPolicyAlgorithm):
|
||||
|
||||
self.logger.record("train/n_updates",
|
||||
self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/action_loss", np.mean(action_losses))
|
||||
self.logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||
self.logger.record("train/actor_loss", np.mean(actor_losses))
|
||||
self.logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
|
Loading…
Reference in New Issue
Block a user