From cb9ee4f302601ac4ff9fffa642cef56aeacb5d7e Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 16 Jul 2022 14:57:34 +0200 Subject: [PATCH] Fixed bugs for Hybrid[Diag=>Full] --- metastable_baselines/distributions/distributions.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index e9b4222..14f8fc2 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -288,6 +288,9 @@ class CholNet(nn.Module): self.param = nn.Parameter( th.ones(self.action_dim), requires_grad=True) elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: + if self.enforce_positive_type == EnforcePositiveType.NONE: + raise Exception( + 'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.') self.stds = nn.Linear(latent_dim, self.action_dim) self.padder = th.nn.ZeroPad2d((0, 1, 1, 0)) # TODO: Init Non-zero? @@ -332,12 +335,16 @@ class CholNet(nn.Module): return diag_chol elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: # TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form. - stds = self.stds(x) + stds = self._ensure_positive_func(self.stds(x)) smol = self._parameterize_full(self.params) big = self.padder(smol) - pearson_cor_chol = big + th.eye(stds.shape[0]) + try: + pearson_cor_chol = big + th.eye(stds.shape[-1]) + except: + import pdb + pdb.set_trace() pearson_cor = pearson_cor_chol.T @ pearson_cor_chol - cov = stds * pearson_cor * stds + cov = stds.T * pearson_cor * stds chol = th.linalg.cholesky(cov) return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: