Allow completely disabling some PPO features (for TRPL)
parent afec4e709c
commit 1d3c2fe005
@@ -94,12 +94,12 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         n_epochs: int = 10,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
-        clip_range: Union[float, Schedule] = 0.2,
+        clip_range: Union[None, float, Schedule] = 0.2,
         clip_range_vf: Union[None, float, Schedule] = None,
         normalize_advantage: bool = True,
         ent_coef: float = 0.0,
         vf_coef: float = 0.5,
-        max_grad_norm: float = 0.5,
+        max_grad_norm: Union[None, float] = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
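With the widened type hints above, both the clipped surrogate objective and gradient-norm clipping can now be switched off by passing None. A minimal usage sketch follows; the import path, policy string, and environment id are placeholders and not part of this commit, only clip_range and max_grad_norm come from the change above:

    # hypothetical import path, for illustration only
    from ppo import PPO

    model = PPO(
        "MlpPolicy",          # placeholder policy
        "Pendulum-v1",        # placeholder environment id
        clip_range=None,      # disable PPO's ratio clipping entirely
        max_grad_norm=None,   # disable gradient-norm clipping entirely
    )
    model.learn(total_timesteps=10_000)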
@@ -275,6 +275,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
 
                 # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                 # clipped surrogate loss
-                surrogate_loss_1 = advantages * ratio
-                surrogate_loss_2 = advantages * \
-                    th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                if self.clip_range is None:
+                    surrogate_loss = advantages * ratio
+                else:
+                    surrogate_loss_1 = advantages * ratio
+                    surrogate_loss_2 = advantages * \
+                        th.clamp(ratio, 1 - clip_range, 1 + clip_range)
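To make the branch above concrete, here is a small standalone PyTorch sketch. The tensor values are made up, and the final pessimistic minimum and negation/mean are standard PPO steps assumed to happen later in train(); they are not part of this hunk:

    import torch as th

    ratio = th.tensor([0.7, 1.0, 1.4])        # pi_new / pi_old per sample
    advantages = th.tensor([1.0, -0.5, 2.0])
    clip_range = 0.2

    # clip_range disabled (None): plain importance-weighted surrogate
    surrogate_unclipped = advantages * ratio

    # clip_range enabled: minimum of the unclipped and clipped terms
    surrogate_clipped = th.min(
        advantages * ratio,
        advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range),
    )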
@@ -283,6 +286,9 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
 
                 surrogate_losses.append(surrogate_loss.item())
 
-                clip_fraction = th.mean(
-                    (th.abs(ratio - 1) > clip_range).float()).item()
+                if self.clip_range is None:
+                    clip_fraction = 0
+                else:
+                    clip_fraction = th.mean(
+                        (th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)
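The hard-coded 0 above reflects that without a clipping interval no sample can be clipped. For reference, a tiny sketch of how the reported clip fraction behaves when clipping is enabled (example values are made up):

    import torch as th

    ratio = th.tensor([0.7, 1.0, 1.4])
    clip_range = 0.2

    # fraction of samples whose ratio fell outside [1 - clip_range, 1 + clip_range]
    clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
    print(clip_fraction)  # ~0.667 here: two of the three ratios left the interval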
@@ -341,6 +347,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(
-                    self.policy.parameters(), self.max_grad_norm)
+                if self.max_grad_norm is not None:
+                    th.nn.utils.clip_grad_norm_(
+                        self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()
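The guard above leaves gradients untouched when max_grad_norm is None. A small standalone sketch of the same pattern with a toy parameter (not the commit's code path):

    import torch as th

    param = th.nn.Parameter(th.zeros(3))
    param.grad = th.tensor([3.0, 4.0, 0.0])   # gradient with L2 norm 5.0
    max_grad_norm = None                      # set to e.g. 0.5 to re-enable clipping

    if max_grad_norm is not None:
        th.nn.utils.clip_grad_norm_([param], max_grad_norm)
    # with max_grad_norm=None the gradient keeps its original norm of 5.0
    print(param.grad.norm())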
@@ -380,6 +387,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
 
         self.logger.record("train/n_updates",
                            self._n_updates, exclude="tensorboard")
-        self.logger.record("train/clip_range", clip_range)
+        if self.clip_range is not None:
+            self.logger.record("train/clip_range", clip_range)
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)