From 479d73ac4b203deeeabce38ff6992a028baa4ac6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 3 Nov 2022 20:13:36 +0100 Subject: [PATCH] Hotfix for exploding gradients --- metastable_baselines/ppo/ppo.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 75f843f..db37a33 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -25,6 +25,8 @@ from metastable_projections.projections.kl_projection_layer import KLProjectionL from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass +from copy import deepcopy + class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): """ @@ -228,10 +230,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): pg_losses, value_losses = [], [] clip_fractions = [] + setbackCtr = 0 + bak = deepcopy(self.policy.state_dict()) + continue_training = True # train for n_epochs epochs for epoch in range(self.n_epochs): + # self.policy.load_state_dict( approx_kl_divs = [] # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(self.batch_size): @@ -253,7 +259,16 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): pol = self.policy features = pol.extract_features(rollout_data.observations) latent_pi, latent_vf = pol.mlp_extractor(features) - p = pol._get_action_dist_from_latent(latent_pi) + try: + p = pol._get_action_dist_from_latent(latent_pi) + except ValueError: + self.policy.load_state_dict(bak) + setbackCtr += 1 + print( + '[!] Gradients Exploded; reseting to last known states (setback number '+str(setbackCtr)+')') + break + del bak + bak = deepcopy(self.policy.state_dict()) p_dist = p.distribution if isinstance(self.projection, WassersteinProjectionLayer): q_dist = new_dist_like_from_sqrt(