Hotfix for exploding gradients

This commit is contained in:
Dominik Moritz Roth 2022-11-03 20:13:36 +01:00
parent 82a174122a
commit 479d73ac4b

View File

@ -25,6 +25,8 @@ from metastable_projections.projections.kl_projection_layer import KLProjectionL
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
from copy import deepcopy
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
""" """
@ -228,10 +230,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
pg_losses, value_losses = [], [] pg_losses, value_losses = [], []
clip_fractions = [] clip_fractions = []
setbackCtr = 0
bak = deepcopy(self.policy.state_dict())
continue_training = True continue_training = True
# train for n_epochs epochs # train for n_epochs epochs
for epoch in range(self.n_epochs): for epoch in range(self.n_epochs):
# self.policy.load_state_dict(
approx_kl_divs = [] approx_kl_divs = []
# Do a complete pass on the rollout buffer # Do a complete pass on the rollout buffer
for rollout_data in self.rollout_buffer.get(self.batch_size): for rollout_data in self.rollout_buffer.get(self.batch_size):
@ -253,7 +259,16 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
pol = self.policy pol = self.policy
features = pol.extract_features(rollout_data.observations) features = pol.extract_features(rollout_data.observations)
latent_pi, latent_vf = pol.mlp_extractor(features) latent_pi, latent_vf = pol.mlp_extractor(features)
try:
p = pol._get_action_dist_from_latent(latent_pi) p = pol._get_action_dist_from_latent(latent_pi)
except ValueError:
self.policy.load_state_dict(bak)
setbackCtr += 1
print(
'[!] Gradients Exploded; reseting to last known states (setback number '+str(setbackCtr)+')')
break
del bak
bak = deepcopy(self.policy.state_dict())
p_dist = p.distribution p_dist = p.distribution
if isinstance(self.projection, WassersteinProjectionLayer): if isinstance(self.projection, WassersteinProjectionLayer):
q_dist = new_dist_like_from_sqrt( q_dist = new_dist_like_from_sqrt(