Support clip_range None

This commit is contained in:
Dominik Moritz Roth 2022-08-28 02:07:18 +02:00
parent 1d3c2fe005
commit eb881559d6

View File

@@ -192,6 +192,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
super()._setup_model() super()._setup_model()
# Initialize schedules for policy/value clipping # Initialize schedules for policy/value clipping
if self.clip_range is not None:
self.clip_range = get_schedule_fn(self.clip_range) self.clip_range = get_schedule_fn(self.clip_range)
if self.clip_range_vf is not None: if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)): if isinstance(self.clip_range_vf, (float, int)):
@@ -208,7 +209,10 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
# Update optimizer learning rate # Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer) self._update_learning_rate(self.policy.optimizer)
# Compute current clip range # Compute current clip range
if self.clip_range:
clip_range = self.clip_range(self._current_progress_remaining) clip_range = self.clip_range(self._current_progress_remaining)
else:
clip_range = None
# Optional: clip range for the value function # Optional: clip range for the value function
if self.clip_range_vf is not None: if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf( clip_range_vf = self.clip_range_vf(
@@ -276,7 +280,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
# clipped surrogate loss # clipped surrogate loss
if self.clip_range is None: if self.clip_range is None:
surrogate_loss = advantages * ratio surrogate_loss = -(advantages * ratio).mean()
else: else:
surrogate_loss_1 = advantages * ratio surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * \ surrogate_loss_2 = advantages * \
@@ -286,7 +290,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
surrogate_losses.append(surrogate_loss.item()) surrogate_losses.append(surrogate_loss.item())
if self.clip_range is None: if clip_range is None:
clip_fraction = 0 clip_fraction = 0
else: else:
clip_fraction = th.mean( clip_fraction = th.mean(