Incremental progress at implementing trl_pg

This commit is contained in:
Dominik Moritz Roth 2022-06-25 13:58:58 +02:00
parent 941b7347f1
commit 25316ec0b8

View File

@ -61,7 +61,7 @@ class TRL_PG(OnPolicyAlgorithm):
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
#TODO: Add new params to doc
# TODO: Add new params to doc
policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": ActorCriticPolicy,
@ -97,7 +97,7 @@ class TRL_PG(OnPolicyAlgorithm):
# Different from PPO:
#projection: BaseProjectionLayer = None,
projection = None,
projection=None,
_init_setup_model: bool = True,
):
@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm):
clip_range = self.clip_range(self._current_progress_remaining)
# Optional: clip range for the value function
if self.clip_range_vf is not None:
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
clip_range_vf = self.clip_range_vf(
self._current_progress_remaining)
surrogate_losses = []
entropy_losses = []
@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm):
# here we go:
pol = self.policy
feat = pol.extract_features(rollout_data.observations)
features = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(features)
p = pol._get_action_dist_from_latent(latent_pi)
proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step
log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions)
# TODO: define b_q and global_step
proj_p = self.projection(pol, p, b_q, self._global_step)
log_prob = proj_p.log_prob(actions)
# or log_prob = pol.log_probability(proj_p, actions)
values = self.value_net(latent_vf)
entropy = proj_p.entropy() # or not...
entropy = proj_p.entropy() # or not...
values = values.flatten()
# Normalize advantage
advantages = rollout_data.advantages
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
advantages = (advantages - advantages.mean()
) / (advantages.std() + 1e-8)
# ratio between old and new policy, should be one at the first iteration
ratio = th.exp(log_prob - rollout_data.old_log_prob)
@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm):
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
# clipped surrogate loss
surrogate_loss_1 = advantages * ratio
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
surrogate_loss_2 = advantages * \
th.clamp(ratio, 1 - clip_range, 1 + clip_range)
surrogate_loss = - \
th.min(surrogate_loss_1, surrogate_loss_2).mean()
surrogate_losses.append(surrogate_loss.item())
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
clip_fraction = th.mean(
(th.abs(ratio - 1) > clip_range).float()).item()
clip_fractions.append(clip_fraction)
if self.clip_range_vf is None:
@ -285,8 +291,9 @@ class TRL_PG(OnPolicyAlgorithm):
entropy_losses.append(entropy_loss.item())
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
#trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement
# trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
trust_region_loss = th.zeros(
1, device=entropy_loss.device) # TODO: Implement
trust_region_losses.append(trust_region_loss.item())
@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm):
# and Schulman blog: http://joschu.net/blog/kl-approx.html
with th.no_grad():
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_div = th.mean(
(th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
approx_kl_divs.append(approx_kl_div)
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
continue_training = False
if self.verbose >= 1:
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
print(
f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
break
# Optimization step
self.policy.optimizer.zero_grad()
loss.backward()
# Clip grad norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
th.nn.utils.clip_grad_norm_(
self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
if not continue_training:
break
self._n_updates += self.n_epochs
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
explained_var = explained_variance(
self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
# Logs
self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
self.logger.record("train/trust_region_loss",
np.mean(trust_region_losses))
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
self.logger.record("train/value_loss", np.mean(value_losses))
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm):
self.logger.record("train/loss", loss.item())
self.logger.record("train/explained_variance", explained_var)
if hasattr(self.policy, "log_std"):
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
self.logger.record(
"train/std", th.exp(self.policy.log_std).mean().item())
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/n_updates",
self._n_updates, exclude="tensorboard")
self.logger.record("train/clip_range", clip_range)
if self.clip_range_vf is not None:
self.logger.record("train/clip_range_vf", clip_range_vf)