Allow completely disabling some PPO features (for TRPL)
This commit is contained in:
parent
afec4e709c
commit
1d3c2fe005
@ -94,12 +94,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
n_epochs: int = 10,
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95,
|
||||
clip_range: Union[float, Schedule] = 0.2,
|
||||
clip_range: Union[None, float, Schedule] = 0.2,
|
||||
clip_range_vf: Union[None, float, Schedule] = None,
|
||||
normalize_advantage: bool = True,
|
||||
ent_coef: float = 0.0,
|
||||
vf_coef: float = 0.5,
|
||||
max_grad_norm: float = 0.5,
|
||||
max_grad_norm: Union[None, float] = 0.5,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
target_kl: Optional[float] = None,
|
||||
@ -275,6 +275,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
|
||||
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
|
||||
# clipped surrogate loss
|
||||
if self.clip_range is None:
|
||||
surrogate_loss = advantages * ratio
|
||||
else:
|
||||
surrogate_loss_1 = advantages * ratio
|
||||
surrogate_loss_2 = advantages * \
|
||||
th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
@ -283,6 +286,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
|
||||
surrogate_losses.append(surrogate_loss.item())
|
||||
|
||||
if self.clip_range is None:
|
||||
clip_fraction = 0
|
||||
else:
|
||||
clip_fraction = th.mean(
|
||||
(th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
@ -341,6 +347,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip grad norm
|
||||
if self.max_grad_norm is not None:
|
||||
th.nn.utils.clip_grad_norm_(
|
||||
self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
@ -380,6 +387,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
|
||||
self.logger.record("train/n_updates",
|
||||
self._n_updates, exclude="tensorboard")
|
||||
if self.clip_range is not None:
|
||||
self.logger.record("train/clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
Loading…
Reference in New Issue
Block a user