Hotfix for exploding gradients
This commit is contained in:
parent
82a174122a
commit
479d73ac4b
@ -25,6 +25,8 @@ from metastable_projections.projections.kl_projection_layer import KLProjectionL
|
|||||||
|
|
||||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
@ -228,10 +230,14 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
pg_losses, value_losses = [], []
|
pg_losses, value_losses = [], []
|
||||||
clip_fractions = []
|
clip_fractions = []
|
||||||
|
|
||||||
|
setbackCtr = 0
|
||||||
|
bak = deepcopy(self.policy.state_dict())
|
||||||
|
|
||||||
continue_training = True
|
continue_training = True
|
||||||
|
|
||||||
# train for n_epochs epochs
|
# train for n_epochs epochs
|
||||||
for epoch in range(self.n_epochs):
|
for epoch in range(self.n_epochs):
|
||||||
|
# self.policy.load_state_dict(
|
||||||
approx_kl_divs = []
|
approx_kl_divs = []
|
||||||
# Do a complete pass on the rollout buffer
|
# Do a complete pass on the rollout buffer
|
||||||
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||||
@ -253,7 +259,16 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
pol = self.policy
|
pol = self.policy
|
||||||
features = pol.extract_features(rollout_data.observations)
|
features = pol.extract_features(rollout_data.observations)
|
||||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||||
|
try:
|
||||||
p = pol._get_action_dist_from_latent(latent_pi)
|
p = pol._get_action_dist_from_latent(latent_pi)
|
||||||
|
except ValueError:
|
||||||
|
self.policy.load_state_dict(bak)
|
||||||
|
setbackCtr += 1
|
||||||
|
print(
|
||||||
|
'[!] Gradients Exploded; reseting to last known states (setback number '+str(setbackCtr)+')')
|
||||||
|
break
|
||||||
|
del bak
|
||||||
|
bak = deepcopy(self.policy.state_dict())
|
||||||
p_dist = p.distribution
|
p_dist = p.distribution
|
||||||
if isinstance(self.projection, WassersteinProjectionLayer):
|
if isinstance(self.projection, WassersteinProjectionLayer):
|
||||||
q_dist = new_dist_like_from_sqrt(
|
q_dist = new_dist_like_from_sqrt(
|
||||||
|
Loading…
Reference in New Issue
Block a user