From 64a7d5ec59a5ad56226cad012b827270a2a1fd6a Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 16 Aug 2022 20:02:33 +0200 Subject: [PATCH] Guarante minimum epsilon when ensuring non-zero (CholNet) --- metastable_baselines/distributions/distributions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 8f9ce4b..5d86047 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -212,7 +212,7 @@ class UniversalGaussianDistribution(SB3_Distribution): mean_actions = nn.Linear(latent_dim, self.action_dim) chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength, - self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type) + self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type, self.epsilon) if self.use_sde: self.sample_weights(self.action_dim) @@ -372,7 +372,7 @@ class UniversalGaussianDistribution(SB3_Distribution): class CholNet(nn.Module): - def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType): + def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType, epsilon): super().__init__() self.latent_dim = latent_dim self.action_dim = action_dim @@ -383,6 +383,8 @@ class CholNet(nn.Module): self.enforce_positive_type = enforce_positive_type self.prob_squashing_type = prob_squashing_type + self.epsilon = epsilon + self._flat_chol_len = action_dim * (action_dim + 1) // 2 # Yes, this is ugly. @@ -557,7 +559,7 @@ class CholNet(nn.Module): return L def _ensure_positive_func(self, x): - return self.enforce_positive_type.apply(x) + return self.enforce_positive_type.apply(x) + self.epsilon def _ensure_diagonal_positive(self, chol): if len(chol.shape) == 1: