Support clip_range None
This commit is contained in:
parent
1d3c2fe005
commit
eb881559d6
@ -192,6 +192,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
super()._setup_model()
|
||||
|
||||
# Initialize schedules for policy/value clipping
|
||||
if self.clip_range is not None:
|
||||
self.clip_range = get_schedule_fn(self.clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
if isinstance(self.clip_range_vf, (float, int)):
|
||||
@ -208,7 +209,10 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
if self.clip_range:
|
||||
clip_range = self.clip_range(self._current_progress_remaining)
|
||||
else:
|
||||
clip_range = None
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(
|
||||
@ -276,7 +280,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
|
||||
# clipped surrogate loss
|
||||
if self.clip_range is None:
|
||||
surrogate_loss = advantages * ratio
|
||||
surrogate_loss = -(advantages * ratio).mean()
|
||||
else:
|
||||
surrogate_loss_1 = advantages * ratio
|
||||
surrogate_loss_2 = advantages * \
|
||||
@ -286,7 +290,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
|
||||
surrogate_losses.append(surrogate_loss.item())
|
||||
|
||||
if self.clip_range is None:
|
||||
if clip_range is None:
|
||||
clip_fraction = 0
|
||||
else:
|
||||
clip_fraction = th.mean(
|
||||
|
Loading…
Reference in New Issue
Block a user