Hybrid Diag -> Full Implemented; Made spherical_chol more efficient

This commit is contained in:
Dominik Moritz Roth 2022-07-16 13:05:35 +02:00
parent f184b88f19
commit fa167b3e5f

View File

@ -79,21 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
for cs in allowedCovStrength:
if ps.value > cs.value:
continue
if ps == Strength.SCALAR and cs == Strength.FULL:
# TODO: Maybe allow?
continue
if ps == Strength.DIAG and cs == Strength.FULL:
# TODO: Implement
continue
if ps == Strength.NONE:
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
else:
for ept in allowedEPTs:
if cs == Strength.FULL:
for pt in allowedPTs:
if pt != ParametrizationType.NONE:
yield (ps, cs, ept, pt)
else:
yield (ps, cs, ept, ProbSquashingType.NONE)
yield (ps, cs, ept, ParametrizationType.NONE)
def make_proba_distribution(
@ -271,7 +266,8 @@ class CholNet(nn.Module):
if self.cov_strength == Strength.NONE:
self.chol = th.ones(self.action_dim) * std_init
elif self.cov_strength == Strength.SCALAR:
self.param = nn.Parameter(std_init, requires_grad=True)
self.param = nn.Parameter(
th.Tensor([std_init]), requires_grad=True)
elif self.cov_strength == Strength.DIAG:
self.params = nn.Parameter(
th.ones(self.action_dim) * std_init, requires_grad=True)
@ -295,11 +291,16 @@ class CholNet(nn.Module):
self.param = nn.Parameter(
th.ones(self.action_dim), requires_grad=True)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
# TODO
pass
self.stds = nn.Linear(latent_dim, self.action_dim)
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
# TODO: Init Non-zero?
self.params = nn.Parameter(
th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True)
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
# TODO
pass
self.factor = nn.Linear(latent_dim, 1)
# TODO: Init Off-axis differently?
self.params = nn.Parameter(
th.ones(self._full_params_len) * std_init, requires_grad=True)
else:
raise Exception("This Exception can't happen")
@ -310,7 +311,7 @@ class CholNet(nn.Module):
return self.chol
elif self.cov_strength == Strength.SCALAR:
return self._ensure_positive_func(
th.ones(self.action_dim) * self.param)
th.ones(self.action_dim) * self.param[0])
elif self.cov_strength == Strength.DIAG:
return self._ensure_positive_func(self.params)
elif self.cov_strength == Strength.FULL:
@ -328,16 +329,23 @@ class CholNet(nn.Module):
return self._parameterize_full(params)
else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
factor = self.factor(x)
factor = self.factor(x)[0]
diag_chol = self._ensure_positive_func(
self.param * factor[0])
self.param * factor)
return diag_chol
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
pass
# TODO
# TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form.
stds = self.stds(x)
smol = self._parameterize_full(self.params)
big = self.padder(smol)
pearson_cor_chol = big + th.eye(stds.shape[0])
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
cov = stds * pearson_cor * stds
chol = th.linalg.cholesky(cov)
return chol
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
# TODO
pass
factor = self.factor(x)
return self._parameterize_full(self.params * factor[0])
raise Exception()
@property
@ -356,14 +364,11 @@ class CholNet(nn.Module):
raise Exception()
def _chol_from_flat(self, flat_chol):
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
chol = fill_triangular(flat_chol)
return self._ensure_diagonal_positive(chol)
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
# self._flat_chol_len, -1, -1)
sphe_chol = fill_triangular(pos_flat_sphe_chol)
chol = self._chol_from_sphe_chol(sphe_chol)
return chol
@ -377,17 +382,26 @@ class CholNet(nn.Module):
# S[i,j] e (0, pi) where i = 2..n, j = 2..i
# We already ensure S > 0 in _chol_from_flat_sphe_chol
# We ensure < pi by applying tanh*pi to all applicable elements
#import pdb
# pdb.set_trace()
S = sphe_chol
n = self.action_dim
L = th.zeros_like(sphe_chol)
for i in range(n):
for j in range(i):
t = S[i, 1]
for k in range(1, j+1):
t *= th.sin(th.tanh(S[i, k])*pi)
t = 1
s = ''
for j in range(i+1):
maybe_cos = 1
s_maybe_cos = ''
if i != j:
t *= th.cos(th.tanh(S[i, j+1])*pi)
L[i, j] = t
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
L[i, j] = S[i, 0] * t * maybe_cos
print('[L_'+str(i+1)+']_'+str(j+1) +
'=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
if j <= i and j < n-1 and i < n:
t *= th.sin(th.tanh(S[i, j+1])*pi)
s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
return L
def _ensure_positive_func(self, x):