Guarante minimum epsilon when ensuring non-zero (CholNet)

This commit is contained in:
Dominik Moritz Roth 2022-08-16 20:02:33 +02:00
parent d35c3d8520
commit 64a7d5ec59

View File

@ -212,7 +212,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
mean_actions = nn.Linear(latent_dim, self.action_dim) mean_actions = nn.Linear(latent_dim, self.action_dim)
chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength, chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type) self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon)
if self.use_sde: if self.use_sde:
self.sample_weights(self.action_dim) self.sample_weights(self.action_dim)
@ -372,7 +372,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
class CholNet(nn.Module): class CholNet(nn.Module):
def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType): def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon):
super().__init__() super().__init__()
self.latent_dim = latent_dim self.latent_dim = latent_dim
self.action_dim = action_dim self.action_dim = action_dim
@ -383,6 +383,8 @@ class CholNet(nn.Module):
self.enforce_positive_type = enforce_positive_type self.enforce_positive_type = enforce_positive_type
self.prob_squashing_type = prob_squashing_type self.prob_squashing_type = prob_squashing_type
self.epsilon = epsilon
self._flat_chol_len = action_dim * (action_dim + 1) // 2 self._flat_chol_len = action_dim * (action_dim + 1) // 2
# Yes, this is ugly. # Yes, this is ugly.
@ -557,7 +559,7 @@ class CholNet(nn.Module):
return L return L
def _ensure_positive_func(self, x): def _ensure_positive_func(self, x):
return self.enforce_positive_type.apply(x) return self.enforce_positive_type.apply(x) + self.epsilon
def _ensure_diagonal_positive(self, chol): def _ensure_diagonal_positive(self, chol):
if len(chol.shape) == 1: if len(chol.shape) == 1: