Incremental progress at implementing trl_pg
This commit is contained in:
parent 941b7347f1
commit 25316ec0b8
@@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm):
         clip_range = self.clip_range(self._current_progress_remaining)
         # Optional: clip range for the value function
         if self.clip_range_vf is not None:
-            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
+            clip_range_vf = self.clip_range_vf(
+                self._current_progress_remaining)

         surrogate_losses = []
         entropy_losses = []
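Note on the calls above: in stable-baselines3, clip_range and clip_range_vf are stored as schedule functions of the remaining training progress, which is why they are called with self._current_progress_remaining. A minimal sketch of that convention using SB3's get_schedule_fn (the numbers are purely illustrative):

from stable_baselines3.common.utils import get_schedule_fn

# A plain float becomes a constant schedule; a callable is used as-is.
clip_range = get_schedule_fn(0.2)
print(clip_range(1.0))  # 0.2 at the start of training (progress_remaining = 1.0)
print(clip_range(0.0))  # still 0.2 at the end: constant schedule

# A callable gives an annealed clip range instead.
linear_clip = get_schedule_fn(lambda progress_remaining: 0.2 * progress_remaining)
print(linear_clip(0.5))  # 0.1 halfway through training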
@@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm):

                 # here we go:
                 pol = self.policy
-                feat = pol.extract_features(rollout_data.observations)
+                features = pol.extract_features(rollout_data.observations)
                 latent_pi, latent_vf = pol.mlp_extractor(features)
                 p = pol._get_action_dist_from_latent(latent_pi)
-                proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step
-                log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions)
+                # TODO: define b_q and global_step
+                proj_p = self.projection(pol, p, b_q, self._global_step)
+                log_prob = proj_p.log_prob(actions)
+                # or log_prob = pol.log_probability(proj_p, actions)
                 values = self.value_net(latent_vf)
                 entropy = proj_p.entropy() # or not...


                 values = values.flatten()
                 # Normalize advantage
                 advantages = rollout_data.advantages
                 if self.normalize_advantage:
-                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+                    advantages = (advantages - advantages.mean()
+                                  ) / (advantages.std() + 1e-8)

                 # ratio between old and new policy, should be one at the first iteration
                 ratio = th.exp(log_prob - rollout_data.old_log_prob)
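For comparison, plain PPO in stable-baselines3 obtains the same quantities from a single policy.evaluate_actions() call; the manual decomposition above is presumably there so the unprojected distribution p is available to hand to the projection before log-probabilities are taken. A rough in-method sketch of the two paths, assuming an SB3 ActorCriticPolicy of the version this code targets (note that the value head lives on the policy, so self.value_net above may need to become pol.value_net like the other pol.* calls):

# Standard SB3 path (no projection step):
values, log_prob, entropy = self.policy.evaluate_actions(
    rollout_data.observations, actions)

# Decomposed path, mirroring the hunk above, so the raw distribution `p`
# can be projected before log-probs, values and entropy are computed:
pol = self.policy
features = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(features)
p = pol._get_action_dist_from_latent(latent_pi)
values = pol.value_net(latent_vf)  # value head is an attribute of the policy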
@@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm):
                 # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
                 # clipped surrogate loss
                 surrogate_loss_1 = advantages * ratio
-                surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
-                surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
+                surrogate_loss_2 = advantages * \
+                    th.clamp(ratio, 1 - clip_range, 1 + clip_range)
+                surrogate_loss = - \
+                    th.min(surrogate_loss_1, surrogate_loss_2).mean()

                 surrogate_losses.append(surrogate_loss.item())

-                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
+                clip_fraction = th.mean(
+                    (th.abs(ratio - 1) > clip_range).float()).item()
                 clip_fractions.append(clip_fraction)

                 if self.clip_range_vf is None:
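The clipped surrogate term itself is unchanged from PPO apart from the renaming. A small standalone check of the clipping behaviour (values are illustrative only):

import torch as th

clip_range = 0.2
ratio = th.tensor([0.5, 1.0, 1.5])       # new/old policy probability ratios
advantages = th.tensor([1.0, 1.0, 1.0])  # all positive for this illustration

surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
# min() keeps the pessimistic estimate: the ratio above 1 + clip_range is clipped
# down to 1.2, while the ratio below 1 - clip_range stays at 0.5 (already the minimum).
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
print(surrogate_loss)  # tensor(-0.9000) = -mean([0.5, 1.0, 1.2])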
@@ -286,7 +292,8 @@ class TRL_PG(OnPolicyAlgorithm):

                 # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
                 # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
-                trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement
+                trust_region_loss = th.zeros(
+                    1, device=entropy_loss.device)  # TODO: Implement

                 trust_region_losses.append(trust_region_loss.item())

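Regarding the TODO just above: once the projection layer is wired in, the zero placeholder would presumably be replaced by the projection's own regression loss between the unprojected and projected distributions. The arguments below are only a guess at the params hinted at in the commented-out line; the actual signature of get_trust_region_loss() must be taken from whichever projection implementation ends up being used:

# Hypothetical wiring, NOT part of this commit: the argument list of
# get_trust_region_loss() is an assumption and has to be checked against
# the projection class actually used.
if self.projection is not None:
    trust_region_loss = self.projection.get_trust_region_loss(pol, p, proj_p)
else:
    trust_region_loss = th.zeros(1, device=entropy_loss.device)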
@@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm):
                 # and Schulman blog: http://joschu.net/blog/kl-approx.html
                 with th.no_grad():
                     log_ratio = log_prob - rollout_data.old_log_prob
-                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
+                    approx_kl_div = th.mean(
+                        (th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                     approx_kl_divs.append(approx_kl_div)

                 if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                     continue_training = False
                     if self.verbose >= 1:
-                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
+                        print(
+                            f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                     break

                 # Optimization step
                 self.policy.optimizer.zero_grad()
                 loss.backward()
                 # Clip grad norm
-                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                th.nn.utils.clip_grad_norm_(
+                    self.policy.parameters(), self.max_grad_norm)
                 self.policy.optimizer.step()

             if not continue_training:
                 break

         self._n_updates += self.n_epochs
-        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
+        explained_var = explained_variance(
+            self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

         # Logs
         self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
         self.logger.record("train/entropy_loss", np.mean(entropy_losses))
-        self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
+        self.logger.record("train/trust_region_loss",
+                           np.mean(trust_region_losses))
         self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
         self.logger.record("train/value_loss", np.mean(value_losses))
         self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
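The approx_kl_div expression is the low-variance estimator E_q[(r - 1) - log r] with r = exp(log_ratio), as described in the Schulman blog post linked above; it is non-negative sample-by-sample, unlike the naive -log r. A standalone check:

import torch as th

log_ratio = th.tensor([-0.5, 0.0, 0.5])
approx_kl = (th.exp(log_ratio) - 1) - log_ratio
print(approx_kl)         # tensor([0.1065, 0.0000, 0.1487]) -- zero only where the ratio is exactly 1
print(approx_kl.mean())  # the batch mean is what gets compared against 1.5 * target_kl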
@@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm):
         self.logger.record("train/loss", loss.item())
         self.logger.record("train/explained_variance", explained_var)
         if hasattr(self.policy, "log_std"):
-            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
+            self.logger.record(
+                "train/std", th.exp(self.policy.log_std).mean().item())

-        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        self.logger.record("train/n_updates",
+                           self._n_updates, exclude="tensorboard")
         self.logger.record("train/clip_range", clip_range)
         if self.clip_range_vf is not None:
             self.logger.record("train/clip_range_vf", clip_range_vf)
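The explained_var logged as train/explained_variance is SB3's explained_variance utility, 1 - Var[returns - values] / Var[returns]: 1 means the value function predicts the returns perfectly, 0 means it does no better than a constant. A quick standalone example:

import numpy as np
from stable_baselines3.common.utils import explained_variance

returns = np.array([1.0, 2.0, 3.0, 4.0])
values = np.array([1.1, 1.9, 3.2, 3.8])  # predictions from the value net
print(explained_variance(values, returns))  # ~0.98: the value function tracks the returns closely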