Fixed bugs for Hybrid[Diag=>Full]

This commit is contained in:
Dominik Moritz Roth 2022-07-16 14:57:34 +02:00
parent ad584d70fd
commit cb9ee4f302

View File

@ -288,6 +288,9 @@ class CholNet(nn.Module):
self.param = nn.Parameter(
th.ones(self.action_dim), requires_grad=True)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
if self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception(
'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.')
self.stds = nn.Linear(latent_dim, self.action_dim)
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
# TODO: Init Non-zero?
@ -332,12 +335,16 @@ class CholNet(nn.Module):
return diag_chol
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
# TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form.
stds = self.stds(x)
stds = self._ensure_positive_func(self.stds(x))
smol = self._parameterize_full(self.params)
big = self.padder(smol)
pearson_cor_chol = big + th.eye(stds.shape[0])
try:
pearson_cor_chol = big + th.eye(stds.shape[-1])
except:
import pdb
pdb.set_trace()
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
cov = stds * pearson_cor * stds
cov = stds.T * pearson_cor * stds
chol = th.linalg.cholesky(cov)
return chol
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: