Fixed bugs for Hybrid[Diag=>Full]
This commit is contained in:
parent
ad584d70fd
commit
cb9ee4f302
@ -288,6 +288,9 @@ class CholNet(nn.Module):
|
||||
self.param = nn.Parameter(
|
||||
th.ones(self.action_dim), requires_grad=True)
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
if self.enforce_positive_type == EnforcePositiveType.NONE:
|
||||
raise Exception(
|
||||
'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.')
|
||||
self.stds = nn.Linear(latent_dim, self.action_dim)
|
||||
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
|
||||
# TODO: Init Non-zero?
|
||||
@ -332,12 +335,16 @@ class CholNet(nn.Module):
|
||||
return diag_chol
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
# TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form.
|
||||
stds = self.stds(x)
|
||||
stds = self._ensure_positive_func(self.stds(x))
|
||||
smol = self._parameterize_full(self.params)
|
||||
big = self.padder(smol)
|
||||
pearson_cor_chol = big + th.eye(stds.shape[0])
|
||||
try:
|
||||
pearson_cor_chol = big + th.eye(stds.shape[-1])
|
||||
except:
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
|
||||
cov = stds * pearson_cor * stds
|
||||
cov = stds.T * pearson_cor * stds
|
||||
chol = th.linalg.cholesky(cov)
|
||||
return chol
|
||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||
|
Loading…
Reference in New Issue
Block a user