Support clip_range None
This commit is contained in:
parent
1d3c2fe005
commit
eb881559d6
@ -192,6 +192,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
super()._setup_model()
|
super()._setup_model()
|
||||||
|
|
||||||
# Initialize schedules for policy/value clipping
|
# Initialize schedules for policy/value clipping
|
||||||
|
if self.clip_range is not None:
|
||||||
self.clip_range = get_schedule_fn(self.clip_range)
|
self.clip_range = get_schedule_fn(self.clip_range)
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
if isinstance(self.clip_range_vf, (float, int)):
|
if isinstance(self.clip_range_vf, (float, int)):
|
||||||
@ -208,7 +209,10 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
# Update optimizer learning rate
|
# Update optimizer learning rate
|
||||||
self._update_learning_rate(self.policy.optimizer)
|
self._update_learning_rate(self.policy.optimizer)
|
||||||
# Compute current clip range
|
# Compute current clip range
|
||||||
|
if self.clip_range:
|
||||||
clip_range = self.clip_range(self._current_progress_remaining)
|
clip_range = self.clip_range(self._current_progress_remaining)
|
||||||
|
else:
|
||||||
|
clip_range = None
|
||||||
# Optional: clip range for the value function
|
# Optional: clip range for the value function
|
||||||
if self.clip_range_vf is not None:
|
if self.clip_range_vf is not None:
|
||||||
clip_range_vf = self.clip_range_vf(
|
clip_range_vf = self.clip_range_vf(
|
||||||
@ -276,7 +280,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
|
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
|
||||||
# clipped surrogate loss
|
# clipped surrogate loss
|
||||||
if self.clip_range is None:
|
if self.clip_range is None:
|
||||||
surrogate_loss = advantages * ratio
|
surrogate_loss = -(advantages * ratio).mean()
|
||||||
else:
|
else:
|
||||||
surrogate_loss_1 = advantages * ratio
|
surrogate_loss_1 = advantages * ratio
|
||||||
surrogate_loss_2 = advantages * \
|
surrogate_loss_2 = advantages * \
|
||||||
@ -286,7 +290,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
|
|
||||||
surrogate_losses.append(surrogate_loss.item())
|
surrogate_losses.append(surrogate_loss.item())
|
||||||
|
|
||||||
if self.clip_range is None:
|
if clip_range is None:
|
||||||
clip_fraction = 0
|
clip_fraction = 0
|
||||||
else:
|
else:
|
||||||
clip_fraction = th.mean(
|
clip_fraction = th.mean(
|
||||||
|
Loading…
Reference in New Issue
Block a user