Action-Loss for SAC
parent 3110275d7b
commit f9a08add40
@@ -122,6 +122,7 @@ class SAC(OffPolicyAlgorithm):
         use_sde_at_warmup: bool = False,
         tensorboard_log: Optional[str] = None,
         create_eval_env: bool = False,
+        action_coef: float = 0.0,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
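Note: action_coef defaults to 0.0, so the extra term vanishes from the actor loss and stock SAC behaviour is preserved. A minimal usage sketch, assuming this fork keeps the usual stable_baselines3 import path and constructor interface (the environment name and coefficient value are illustrative, not from the commit):

from stable_baselines3 import SAC  # assumption: fork is importable like stock SB3

model = SAC("MlpPolicy", "Pendulum-v1", action_coef=0.1, verbose=1)
model.learn(total_timesteps=10_000)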
@@ -176,6 +177,8 @@ class SAC(OffPolicyAlgorithm):
         self.n_steps = buffer_size
         self.gae_lambda = False

+        self.action_coef = action_coef
+
         if projection != None:
             print('[!] An projection was supplied! Will be ignored!')

@@ -236,6 +239,7 @@ class SAC(OffPolicyAlgorithm):

         ent_coef_losses, ent_coefs = [], []
         actor_losses, critic_losses = [], []
+        action_losses = []

         for gradient_step in range(gradient_steps):
             # Sample replay buffer
@@ -330,6 +334,11 @@ class SAC(OffPolicyAlgorithm):

             projection_loss = th.zeros(1)

+            # 'Principle of least action'
+            action_loss = th.mean(th.square(actions_pi))
+
+            action_losses.append(action_loss.item())
+
             # Compute critic loss
             critic_loss_raw = 0.5 * sum(F.mse_loss(current_q, target_q_values)
                                         for current_q in current_q_values)
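The penalty itself is just the mean squared magnitude of the actions sampled from the current policy. A standalone sketch of the term, assuming tanh-squashed actions in [-1, 1] as SAC's policy produces (batch and action dimensions are made up):

import torch as th

# Hypothetical batch: 256 sampled actions of dimension 6, squashed by tanh
actions_pi = th.tanh(th.randn(256, 6))

# 'Principle of least action': mean squared action magnitude,
# a scalar in [0, 1] for tanh-squashed actions
action_loss = th.mean(th.square(actions_pi))
print(action_loss.item())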
@@ -347,7 +356,8 @@ class SAC(OffPolicyAlgorithm):
             q_values_pi = th.cat(self.critic(
                 replay_data.observations, actions_pi), dim=1)
             min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
-            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
+            actor_loss = (ent_coef * log_prob - min_qf_pi +
+                          self.action_coef * action_loss).mean()
             actor_losses.append(actor_loss.item())

             # Optimize the actor
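Because action_loss is already reduced to a scalar, adding it inside the final .mean() simply broadcasts a constant across the batch; the actor gradient still flows through it via actions_pi. A sketch of the combined objective under assumed SAC shapes (all values here are placeholders, not from the commit):

import torch as th

batch = 256
ent_coef = th.tensor(0.2)       # entropy temperature
log_prob = th.randn(batch, 1)   # log pi(a|s) per sample
min_qf_pi = th.randn(batch, 1)  # min over critics of Q(s, a)
action_loss = th.tensor(0.05)   # scalar penalty from the step above
action_coef = 0.1               # hypothetical weight; 0.0 disables the term

# The scalar penalty broadcasts over the batch, so the .mean() leaves it
# scaled only by action_coef, matching the committed line.
actor_loss = (ent_coef * log_prob - min_qf_pi + action_coef * action_loss).mean()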
@@ -364,6 +374,7 @@ class SAC(OffPolicyAlgorithm):

         self.logger.record("train/n_updates",
                            self._n_updates, exclude="tensorboard")
+        self.logger.record("train/action_loss", np.mean(action_losses))
         self.logger.record("train/ent_coef", np.mean(ent_coefs))
         self.logger.record("train/actor_loss", np.mean(actor_losses))
         self.logger.record("train/critic_loss", np.mean(critic_losses))
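With this last piece in place, train/action_loss appears next to the other train/ scalars, following the same accumulate-then-average pattern already used for the actor and critic losses. A minimal sketch of that pattern (values are illustrative):

import numpy as np

action_losses = [0.051, 0.049, 0.047]  # one scalar collected per gradient step
print(np.mean(action_losses))          # the value recorded as train/action_loss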