Allow completely disabling some PPO features (for TRPL)

Dominik Moritz Roth 2022-08-28 00:26:44 +02:00
parent afec4e709c
commit 1d3c2fe005


@@ -94,12 +94,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         n_epochs: int = 10,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
-        clip_range: Union[float, Schedule] = 0.2,
+        clip_range: Union[None, float, Schedule] = 0.2,
         clip_range_vf: Union[None, float, Schedule] = None,
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
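
Passing clip_range=None now skips the clipped surrogate objective entirely, and max_grad_norm=None skips gradient-norm clipping, presumably so that the TRPL trust-region projection is the only mechanism constraining the policy update. A minimal usage sketch, assuming this fork keeps the standard stable-baselines3 constructor interface (the import path and environment id below are illustrative, not taken from this commit):

    from metastable_baselines.ppo import PPO  # hypothetical import path for this fork

    # Turn off both PPO safeguards; the trust region is then enforced
    # by the projection layer rather than by ratio clipping.
    model = PPO(
        "MlpPolicy",
        "Pendulum-v1",
        clip_range=None,      # no clipped surrogate objective
        max_grad_norm=None,   # no gradient-norm clipping
    )
    model.learn(total_timesteps=10_000)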
@@ -275,16 +275,22 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                 # clipped surrogate loss
-                surrogate_loss_1 = advantages * ratio
-                surrogate_loss_2 = advantages * \
-                    th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                surrogate_loss = - \
-                    th.min(surrogate_loss_1, surrogate_loss_2).mean()
+                if self.clip_range is None:
+                    surrogate_loss = -(advantages * ratio).mean()
+                else:
+                    surrogate_loss_1 = advantages * ratio
+                    surrogate_loss_2 = advantages * \
+                        th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                    surrogate_loss = - \
+                        th.min(surrogate_loss_1, surrogate_loss_2).mean()
                 surrogate_losses.append(surrogate_loss.item())

-                clip_fraction = th.mean(
-                    (th.abs(ratio - 1) > clip_range).float()).item()
+                if self.clip_range is None:
+                    clip_fraction = 0.0
+                else:
+                    clip_fraction = th.mean(
+                        (th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)

                 if self.clip_range_vf is None:
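
With self.clip_range set to None the objective reduces to the plain importance-weighted policy gradient, -E[A * r(theta)]; otherwise it remains PPO's clipped surrogate, -E[min(A * r, A * clip(r, 1 - eps, 1 + eps))]. The same branching as a self-contained sketch (the helper name and signature are illustrative, not part of the fork):

    import torch as th
    from typing import Optional

    def surrogate_loss(advantages: th.Tensor, ratio: th.Tensor,
                       clip_range: Optional[float]) -> th.Tensor:
        if clip_range is None:
            # Plain importance-weighted policy gradient; no clipping.
            return -(advantages * ratio).mean()
        # Standard PPO clipped surrogate objective.
        unclipped = advantages * ratio
        clipped = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
        return -th.min(unclipped, clipped).mean()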
@@ -341,8 +347,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(
-                    self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(
+                        self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()

             if not continue_training:
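
When max_grad_norm is None the gradients now reach the optimizer unscaled. The optimizer-step pattern in isolation (a sketch; the helper and its parameters stand in for the attributes used above):

    import torch as th
    from typing import Optional

    def optimizer_step(policy: th.nn.Module, optimizer: th.optim.Optimizer,
                       loss: th.Tensor, max_grad_norm: Optional[float]) -> None:
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            # Rescales all gradients in place so their global L2 norm
            # does not exceed max_grad_norm; skipped entirely when disabled.
            th.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
        optimizer.step()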
@@ -380,7 +387,8 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.logger.record("train/n_updates",
                            self._n_updates, exclude="tensorboard")
-        self.logger.record("train/clip_range", clip_range)
+        if self.clip_range is not None:
+            self.logger.record("train/clip_range", clip_range)
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)