Implemented action_loss

Dominik Moritz Roth 2022-09-03 11:16:29 +02:00
parent 2f05474091
commit e4a8cfc349


@@ -99,6 +99,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
+        action_coef: float = 0.0,
         max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
@@ -181,6 +182,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.normalize_advantage = normalize_advantage
         self.target_kl = target_kl
+        self.action_coef = action_coef
+
         # Different from PPO:
         self.projection = projection
         self._global_steps = 0
@@ -221,6 +224,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         surrogate_losses = []
         entropy_losses = []
         trust_region_losses = []
+        action_losses = []
         pg_losses, value_losses = [], []
         clip_fractions = []
@@ -325,7 +329,13 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 trust_region_losses.append(trust_region_loss.item())

-                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + trust_region_loss
+                # 'Principle of least action'
+                action_loss = th.square(actions).mean()
+                action_losses.append(action_loss.item())
+
+                policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
+                    trust_region_loss + self.action_coef * action_loss

                 pg_losses.append(policy_loss.item())

                 loss = policy_loss + self.vf_coef * value_loss
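
In isolation, the new term is an L2 penalty on the actions drawn from the rollout buffer, reduced with a mean so it is a scalar that can be summed with the other scalar loss terms and logged via .item(). A minimal standalone sketch, assuming actions is a (batch_size, action_dim) tensor as in the surrounding training loop; the shapes and the 0.01 weight below are illustrative, not taken from the commit:

import torch as th

actions = th.randn(64, 6)                    # stand-in for rollout_data.actions
action_coef = 0.01                           # illustrative weight; the commit defaults to 0.0

action_loss = th.square(actions).mean()      # mean squared action magnitude, a scalar
penalty = action_coef * action_loss          # contribution added to the policy loss
print(action_loss.item(), penalty.item())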
@@ -369,6 +379,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.logger.record("train/trust_region_loss",
                            np.mean(trust_region_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
+        self.logger.record("train/action_loss", np.mean(action_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
         self.logger.record("train/clip_fraction", np.mean(clip_fractions))
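
For completeness, a hedged usage sketch: assuming this PPO class keeps the Stable-Baselines3-style constructor interface shown in the hunks above, the penalty is switched on by passing a non-zero action_coef. The import path, environment id, and hyperparameter values below are illustrative assumptions, not taken from the commit:

from metastable_baselines.ppo import PPO   # hypothetical import path

model = PPO(
    "MlpPolicy",
    "Pendulum-v1",           # illustrative continuous-control environment
    action_coef=0.01,        # weight of the new action penalty; the default 0.0 disables it
    vf_coef=0.5,
    ent_coef=0.0,
)
model.learn(total_timesteps=100_000)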