Allow completely disabling some PPO features (for TRPL)

Dominik Moritz Roth 2022-08-28 00:26:44 +02:00
parent afec4e709c
commit 1d3c2fe005


@@ -94,12 +94,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         n_epochs: int = 10,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
-        clip_range: Union[float, Schedule] = 0.2,
+        clip_range: Union[None, float, Schedule] = 0.2,
         clip_range_vf: Union[None, float, Schedule] = None,
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
@@ -275,6 +275,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                 # clipped surrogate loss
+                if self.clip_range is None:
+                    surrogate_loss = advantages * ratio
+                else:
                 surrogate_loss_1 = advantages * ratio
                 surrogate_loss_2 = advantages * \
                     th.clamp(ratio, 1 - clip_range, 1 + clip_range)
@@ -283,6 +286,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 surrogate_losses.append(surrogate_loss.item())
+                if self.clip_range is None:
+                    clip_fraction = 0
+                else:
                 clip_fraction = th.mean(
                     (th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)
@@ -341,6 +347,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
+                if self.max_grad_norm is not None:
                 th.nn.utils.clip_grad_norm_(
                     self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
@@ -380,6 +387,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         self.logger.record("train/n_updates",
                            self._n_updates, exclude="tensorboard")
+        if self.clip_range is not None:
         self.logger.record("train/clip_range", clip_range)
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)