Implemented action_loss

parent 2f05474091
commit e4a8cfc349
@@ -99,6 +99,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
+        action_coef: float = 0.0,
         max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
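
For context, the new coefficient is an ordinary constructor keyword argument. A minimal usage sketch follows; the import path, policy string, and environment name are assumptions for illustration, not taken from this commit:

    from ppo import PPO  # this repository's PPO variant (assumed import path)

    model = PPO(
        "MlpPolicy",
        "Pendulum-v1",
        action_coef=0.01,  # weight of the new squared-action penalty; 0.0 (default) disables it
    )
    model.learn(total_timesteps=100_000)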
@@ -181,6 +182,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.normalize_advantage = normalize_advantage
         self.target_kl = target_kl
 
+        self.action_coef = action_coef
+
         # Different from PPO:
         self.projection = projection
         self._global_steps = 0
@@ -221,6 +224,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         surrogate_losses = []
         entropy_losses = []
         trust_region_losses = []
+        action_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
 
@@ -325,7 +329,13 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
 
                 trust_region_losses.append(trust_region_loss.item())
 
-                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
+                # 'Principle of least action': penalize squared action magnitude
+                action_loss = th.square(actions).mean()
+
+                action_losses.append(action_loss.item())
+
+                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
+                    trust_region_loss + self.action_coef * action_loss
                 pg_losses.append(policy_loss.item())
 
                 loss = policy_loss + self.vf_coef * value_loss
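
Standalone, the new term is just the mean squared action magnitude, so small actions are cheaper than large ones (the 'principle of least action' named in the comment). A minimal sketch with made-up numbers:

    import torch as th

    # Made-up batch of actions (batch_size=4, action_dim=2), illustration only.
    actions = th.tensor([[0.5, -1.0],
                         [0.1,  0.2],
                         [2.0,  0.0],
                         [-0.3, 0.7]])

    action_loss = th.square(actions).mean()
    print(action_loss)  # tensor(0.7350)

Scaled by action_coef, this penalty is added to the policy loss, so the optimizer trades a bit of return for smaller actions whenever action_coef > 0.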
@@ -369,6 +379,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.logger.record("train/trust_region_loss",
                            np.mean(trust_region_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+        self.logger.record("train/action_loss", np.mean(action_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
         self.logger.record("train/clip_fraction", np.mean(clip_fractions))