Incremental progress at implementing trl_pg
This commit is contained in:
parent
941b7347f1
commit
25316ec0b8
@ -61,7 +61,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
#TODO: Add new params to doc
|
||||
# TODO: Add new params to doc
|
||||
|
||||
policy_aliases: Dict[str, Type[BasePolicy]] = {
|
||||
"MlpPolicy": ActorCriticPolicy,
|
||||
@ -97,7 +97,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
|
||||
# Different from PPO:
|
||||
#projection: BaseProjectionLayer = None,
|
||||
projection = None,
|
||||
projection=None,
|
||||
|
||||
_init_setup_model: bool = True,
|
||||
):
|
||||
@ -190,7 +190,8 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
clip_range = self.clip_range(self._current_progress_remaining)
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
||||
clip_range_vf = self.clip_range_vf(
|
||||
self._current_progress_remaining)
|
||||
|
||||
surrogate_losses = []
|
||||
entropy_losses = []
|
||||
@ -233,20 +234,22 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
|
||||
# here we go:
|
||||
pol = self.policy
|
||||
feat = pol.extract_features(rollout_data.observations)
|
||||
features = pol.extract_features(rollout_data.observations)
|
||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||
p = pol._get_action_dist_from_latent(latent_pi)
|
||||
proj_p = self.projection(pol, p, b_q, self._global_step) # TODO: define b_q and global_step
|
||||
log_prob = proj_p.log_prob(actions) # or log_prob = pol.log_probability(proj_p, actions)
|
||||
# TODO: define b_q and global_step
|
||||
proj_p = self.projection(pol, p, b_q, self._global_step)
|
||||
log_prob = proj_p.log_prob(actions)
|
||||
# or log_prob = pol.log_probability(proj_p, actions)
|
||||
values = self.value_net(latent_vf)
|
||||
entropy = proj_p.entropy() # or not...
|
||||
|
||||
|
||||
values = values.flatten()
|
||||
# Normalize advantage
|
||||
advantages = rollout_data.advantages
|
||||
if self.normalize_advantage:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
advantages = (advantages - advantages.mean()
|
||||
) / (advantages.std() + 1e-8)
|
||||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
@ -254,12 +257,15 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
# Difference from PPO: We renamed 'policy_loss' to 'surrogate_loss'
|
||||
# clipped surrogate loss
|
||||
surrogate_loss_1 = advantages * ratio
|
||||
surrogate_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
surrogate_loss = -th.min(surrogate_loss_1, surrogate_loss_2).mean()
|
||||
surrogate_loss_2 = advantages * \
|
||||
th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
surrogate_loss = - \
|
||||
th.min(surrogate_loss_1, surrogate_loss_2).mean()
|
||||
|
||||
surrogate_losses.append(surrogate_loss.item())
|
||||
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fraction = th.mean(
|
||||
(th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
|
||||
if self.clip_range_vf is None:
|
||||
@ -285,8 +291,9 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
|
||||
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
|
||||
#trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
|
||||
trust_region_loss = th.zeros(1, device=entropy_loss.device) # TODO: Implement
|
||||
# trust_region_loss = self.projection.get_trust_region_loss()#TODO: params
|
||||
trust_region_loss = th.zeros(
|
||||
1, device=entropy_loss.device) # TODO: Implement
|
||||
|
||||
trust_region_losses.append(trust_region_loss.item())
|
||||
|
||||
@ -301,32 +308,37 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
# and Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
with th.no_grad():
|
||||
log_ratio = log_prob - rollout_data.old_log_prob
|
||||
approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
|
||||
approx_kl_div = th.mean(
|
||||
(th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
|
||||
approx_kl_divs.append(approx_kl_div)
|
||||
|
||||
if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
|
||||
continue_training = False
|
||||
if self.verbose >= 1:
|
||||
print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
|
||||
print(
|
||||
f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
|
||||
break
|
||||
|
||||
# Optimization step
|
||||
self.policy.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
# Clip grad norm
|
||||
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
||||
th.nn.utils.clip_grad_norm_(
|
||||
self.policy.parameters(), self.max_grad_norm)
|
||||
self.policy.optimizer.step()
|
||||
|
||||
if not continue_training:
|
||||
break
|
||||
|
||||
self._n_updates += self.n_epochs
|
||||
explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
|
||||
explained_var = explained_variance(
|
||||
self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())
|
||||
|
||||
# Logs
|
||||
self.logger.record("train/surrogate_loss", np.mean(surrogate_losses))
|
||||
self.logger.record("train/entropy_loss", np.mean(entropy_losses))
|
||||
self.logger.record("train/trust_region_loss", np.mean(trust_region_losses))
|
||||
self.logger.record("train/trust_region_loss",
|
||||
np.mean(trust_region_losses))
|
||||
self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
|
||||
self.logger.record("train/value_loss", np.mean(value_losses))
|
||||
self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
|
||||
@ -334,9 +346,11 @@ class TRL_PG(OnPolicyAlgorithm):
|
||||
self.logger.record("train/loss", loss.item())
|
||||
self.logger.record("train/explained_variance", explained_var)
|
||||
if hasattr(self.policy, "log_std"):
|
||||
self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
self.logger.record(
|
||||
"train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/n_updates",
|
||||
self._n_updates, exclude="tensorboard")
|
||||
self.logger.record("train/clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
self.logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
Loading…
Reference in New Issue
Block a user