Guarantee minimum epsilon when ensuring non-zero (CholNet)
This commit is contained in:
parent
d35c3d8520
commit
64a7d5ec59
@ -212,7 +212,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||||
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
|
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
|
||||||
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type)
|
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon)
|
||||||
|
|
||||||
if self.use_sde:
|
if self.use_sde:
|
||||||
self.sample_weights(self.action_dim)
|
self.sample_weights(self.action_dim)
|
||||||
@ -372,7 +372,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
|
|
||||||
class CholNet(nn.Module):
|
class CholNet(nn.Module):
|
||||||
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType):
|
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.latent_dim = latent_dim
|
self.latent_dim = latent_dim
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
@ -383,6 +383,8 @@ class CholNet(nn.Module):
|
|||||||
self.enforce_positive_type = enforce_positive_type
|
self.enforce_positive_type = enforce_positive_type
|
||||||
self.prob_squashing_type = prob_squashing_type
|
self.prob_squashing_type = prob_squashing_type
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
self._flat_chol_len = action_dim * (action_dim + 1) // 2
|
||||||
|
|
||||||
# Yes, this is ugly.
|
# Yes, this is ugly.
|
||||||
@ -557,7 +559,7 @@ class CholNet(nn.Module):
|
|||||||
return L
|
return L
|
||||||
|
|
||||||
def _ensure_positive_func(self, x):
|
def _ensure_positive_func(self, x):
|
||||||
return self.enforce_positive_type.apply(x)
|
return self.enforce_positive_type.apply(x) + self.epsilon
|
||||||
|
|
||||||
def _ensure_diagonal_positive(self, chol):
|
def _ensure_diagonal_positive(self, chol):
|
||||||
if len(chol.shape) == 1:
|
if len(chol.shape) == 1:
|
||||||
|
Loading…
Reference in New Issue
Block a user