Hotfix for exploding gradients
This commit is contained in:
parent
82a174122a
commit
479d73ac4b
@ -25,6 +25,8 @@ from metastable_projections.projections.kl_projection_layer import KLProjectionL
|
||||
|
||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
"""
|
||||
@ -228,10 +230,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
pg_losses, value_losses = [], []
|
||||
clip_fractions = []
|
||||
|
||||
setbackCtr = 0
|
||||
bak = deepcopy(self.policy.state_dict())
|
||||
|
||||
continue_training = True
|
||||
|
||||
# train for n_epochs epochs
|
||||
for epoch in range(self.n_epochs):
|
||||
# self.policy.load_state_dict(
|
||||
approx_kl_divs = []
|
||||
# Do a complete pass on the rollout buffer
|
||||
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||
@ -253,7 +259,16 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
pol = self.policy
|
||||
features = pol.extract_features(rollout_data.observations)
|
||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||
try:
|
||||
p = pol._get_action_dist_from_latent(latent_pi)
|
||||
except ValueError:
|
||||
self.policy.load_state_dict(bak)
|
||||
setbackCtr += 1
|
||||
print(
|
||||
'[!] Gradients Exploded; reseting to last known states (setback number '+str(setbackCtr)+')')
|
||||
break
|
||||
del bak
|
||||
bak = deepcopy(self.policy.state_dict())
|
||||
p_dist = p.distribution
|
||||
if isinstance(self.projection, WassersteinProjectionLayer):
|
||||
q_dist = new_dist_like_from_sqrt(
|
||||
|
Loading…
Reference in New Issue
Block a user