Guarantee minimum epsilon when ensuring non-zero (CholNet)
This commit is contained in:
parent
d35c3d8520
commit
64a7d5ec59
@ -212,7 +212,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
|
||||
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type)
|
||||
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon)
|
||||
|
||||
if self.use_sde:
|
||||
self.sample_weights(self.action_dim)
|
||||
@ -372,7 +372,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
|
||||
class CholNet(nn.Module):
|
||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType):
|
||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon):
|
||||
super().__init__()
|
||||
self.latent_dim = latent_dim
|
||||
self.action_dim = action_dim
|
||||
@ -383,6 +383,8 @@ class CholNet(nn.Module):
|
||||
self.enforce_positive_type = enforce_positive_type
|
||||
self.prob_squashing_type = prob_squashing_type
|
||||
|
||||
self.epsilon = epsilon
|
||||
|
||||
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
||||
|
||||
# Yes, this is ugly.
|
||||
@ -557,7 +559,7 @@ class CholNet(nn.Module):
|
||||
return L
|
||||
|
||||
def _ensure_positive_func(self, x):
|
||||
return self.enforce_positive_type.apply(x)
|
||||
return self.enforce_positive_type.apply(x) + self.epsilon
|
||||
|
||||
def _ensure_diagonal_positive(self, chol):
|
||||
if len(chol.shape) == 1:
|
||||
|
Loading…
Reference in New Issue
Block a user