Incremental progress at implementing trl_pg

This commit is contained in:
Dominik Moritz Roth 2022-06-25 13:58:58 +02:00
parent 941b7347f1
commit 25316ec0b8

View File

@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm):
clip_range = self.clip_range(self._current_progress_remaining) clip_range = self.clip_range(self._current_progress_remaining)
# Optional: clip range for the value function # Optional: clip range for the value function
if self.clip_range_vf is not None: if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining) clip_range_vf = self.clip_range_vf(
self._current_progress_remaining)
surrogate_losses = [] surrogate_losses = []
entropy_losses = [] entropy_losses = []
@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm):
# here we go: # here we go:
pol = self.policy pol = self.policy
feat = pol.extract_features(rollout_data.observations) features = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(features) latent_pi, latent_vf = pol.mlp_extractor(features)
p = pol._get_action_dist_from_latent(latent_pi) p = pol._get_action_dist_from_latent(latent_pi)
proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step # TODO: define b_q and global_step
log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions) proj_p = self.projection(pol, p, b_q, self._global_step)
log_prob = proj_p.log_prob(actions)
# or log_prob = pol.log_probability(proj_p, actions)
values = self.value_net(latent_vf) values = self.value_net(latent_vf)
entropy = proj_p.entropy() # or not... entropy = proj_p.entropy() # or not...
values = values.flatten() values = values.flatten()
# Normalize advantage # Normalize advantage
advantages = rollout_data.advantages advantages = rollout_data.advantages
if self.normalize_advantage: if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) advantages = (advantages - advantages.mean()
) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration # ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob) ratio = th.exp(log_prob - rollout_data.old_log_prob)
@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm):
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss' # Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
# clipped surrogate loss # clipped surrogate loss
surrogate_loss_1 = advantages * ratio surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) surrogate_loss_2 = advantages * \
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean() th.clamp(ratio, 1 - clip_range, 1 + clip_range)
surrogate_loss = - \
th.min(surrogate_loss_1, surrogate_loss_2).mean()
surrogate_losses.append(surrogate_loss.item()) surrogate_losses.append(surrogate_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() clip_fraction = th.mean(
(th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction) clip_fractions.append(clip_fraction)
if self.clip_range_vf is None: if self.clip_range_vf is None:
@ -286,7 +292,8 @@ class TRL_PG(OnPolicyAlgorithm):
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss # Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
# trust_region_loss = self.projection.get_trust_region_loss()#TODO: params # trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement trust_region_loss = th.zeros(
1, device=entropy_loss.device) # TODO: Implement
trust_region_losses.append(trust_region_loss.item()) trust_region_losses.append(trust_region_loss.item())
@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm):
# and Schulman blog: http://joschu.net/blog/kl-approx.html # and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad(): with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() approx_kl_div = th.mean(
(th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div) approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False continue_training = False
if self.verbose >= 1: if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") print(
f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break break
# Optimization step # Optimization step
self.policy.optimizer.zero_grad() self.policy.optimizer.zero_grad()
loss.backward() loss.backward()
# Clip grad norm # Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) th.nn.utils.clip_grad_norm_(
self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step() self.policy.optimizer.step()
if not continue_training: if not continue_training:
break break
self._n_updates += self.n_epochs self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) explained_var = explained_variance(
self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs # Logs
self.logger.record("train/surrogate_loss", np.mean(surrogate_losses)) self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
self.logger.record("train/entropy_loss", np.mean(entropy_losses)) self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/trust_region_loss", np.mean(trust_region_losses)) self.logger.record("train/trust_region_loss",
np.mean(trust_region_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm):
self.logger.record("train/loss", loss.item()) self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var) self.logger.record("train/explained_variance", explained_var)
if hasattr(self.policy, "log_std"): if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) self.logger.record(
"train/std", th.exp(self.policy.log_std).mean().item())
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/n_updates",
self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range) self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None: if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf) self.logger.record("train/clip_range_vf", clip_range_vf)