diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index e9b4222..14f8fc2 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -288,6 +288,9 @@ class CholNet(nn.Module): self.param = nn.Parameter( th.ones(self.action_dim), requires_grad=True) elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: + if self.enforce_positive_type == EnforcePositiveType.NONE: + raise Exception( + 'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.') self.stds = nn.Linear(latent_dim, self.action_dim) self.padder = th.nn.ZeroPad2d((0, 1, 1, 0)) # TODO: Init Non-zero? @@ -332,12 +335,16 @@ class CholNet(nn.Module): return diag_chol elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: # TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form. - stds = self.stds(x) + stds = self._ensure_positive_func(self.stds(x)) smol = self._parameterize_full(self.params) big = self.padder(smol) - pearson_cor_chol = big + th.eye(stds.shape[0]) + try: + pearson_cor_chol = big + th.eye(stds.shape[-1]) + except: + import pdb + pdb.set_trace() pearson_cor = pearson_cor_chol.T @ pearson_cor_chol - cov = stds * pearson_cor * stds + cov = stds.T * pearson_cor * stds chol = th.linalg.cholesky(cov) return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: