Fix bug when self.max_grad_norm=None

This commit is contained in:
Dominik Moritz Roth 2024-01-23 13:32:38 +01:00
parent c67f78159b
commit 8bbb01504b

View File

@@ -92,7 +92,7 @@ class TRPL(BetterOnPolicyAlgorithm):
normalize_advantage: bool = True, normalize_advantage: bool = True,
ent_coef: float = 0.0, ent_coef: float = 0.0,
vf_coef: float = 0.5, vf_coef: float = 0.5,
max_grad_norm: float = 0.5, max_grad_norm: Union[float, None] = None,
use_sde: bool = False, use_sde: bool = False,
sde_sample_freq: int = -1, sde_sample_freq: int = -1,
use_pca: bool = False, use_pca: bool = False,
@@ -306,6 +306,7 @@ class TRPL(BetterOnPolicyAlgorithm):
self.policy.optimizer.zero_grad() self.policy.optimizer.zero_grad()
loss.backward() loss.backward()
# Clip grad norm # Clip grad norm
if self.max_grad_norm is not None:
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step() self.policy.optimizer.step()