Hybrid Diag -> Full Implemented; Made spherical_chol more efficient
This commit is contained in:
parent
f184b88f19
commit
fa167b3e5f
@ -79,21 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
||||
for cs in allowedCovStrength:
|
||||
if ps.value > cs.value:
|
||||
continue
|
||||
if ps == Strength.SCALAR and cs == Strength.FULL:
|
||||
# TODO: Maybe allow?
|
||||
continue
|
||||
if ps == Strength.DIAG and cs == Strength.FULL:
|
||||
# TODO: Implement
|
||||
continue
|
||||
if ps == Strength.NONE:
|
||||
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
|
||||
else:
|
||||
for ept in allowedEPTs:
|
||||
if cs == Strength.FULL:
|
||||
for pt in allowedPTs:
|
||||
for ept in allowedEPTs:
|
||||
if cs == Strength.FULL:
|
||||
for pt in allowedPTs:
|
||||
if pt != ParametrizationType.NONE:
|
||||
yield (ps, cs, ept, pt)
|
||||
else:
|
||||
yield (ps, cs, ept, ProbSquashingType.NONE)
|
||||
else:
|
||||
yield (ps, cs, ept, ParametrizationType.NONE)
|
||||
|
||||
|
||||
def make_proba_distribution(
|
||||
@ -271,7 +266,8 @@ class CholNet(nn.Module):
|
||||
if self.cov_strength == Strength.NONE:
|
||||
self.chol = th.ones(self.action_dim) * std_init
|
||||
elif self.cov_strength == Strength.SCALAR:
|
||||
self.param = nn.Parameter(std_init, requires_grad=True)
|
||||
self.param = nn.Parameter(
|
||||
th.Tensor([std_init]), requires_grad=True)
|
||||
elif self.cov_strength == Strength.DIAG:
|
||||
self.params = nn.Parameter(
|
||||
th.ones(self.action_dim) * std_init, requires_grad=True)
|
||||
@ -295,11 +291,16 @@ class CholNet(nn.Module):
|
||||
self.param = nn.Parameter(
|
||||
th.ones(self.action_dim), requires_grad=True)
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
# TODO
|
||||
pass
|
||||
self.stds = nn.Linear(latent_dim, self.action_dim)
|
||||
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
|
||||
# TODO: Init Non-zero?
|
||||
self.params = nn.Parameter(
|
||||
th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True)
|
||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||
# TODO
|
||||
pass
|
||||
self.factor = nn.Linear(latent_dim, 1)
|
||||
# TODO: Init Off-axis differently?
|
||||
self.params = nn.Parameter(
|
||||
th.ones(self._full_params_len) * std_init, requires_grad=True)
|
||||
else:
|
||||
raise Exception("This Exception can't happen")
|
||||
|
||||
@ -310,7 +311,7 @@ class CholNet(nn.Module):
|
||||
return self.chol
|
||||
elif self.cov_strength == Strength.SCALAR:
|
||||
return self._ensure_positive_func(
|
||||
th.ones(self.action_dim) * self.param)
|
||||
th.ones(self.action_dim) * self.param[0])
|
||||
elif self.cov_strength == Strength.DIAG:
|
||||
return self._ensure_positive_func(self.params)
|
||||
elif self.cov_strength == Strength.FULL:
|
||||
@ -328,16 +329,23 @@ class CholNet(nn.Module):
|
||||
return self._parameterize_full(params)
|
||||
else:
|
||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||
factor = self.factor(x)
|
||||
factor = self.factor(x)[0]
|
||||
diag_chol = self._ensure_positive_func(
|
||||
self.param * factor[0])
|
||||
self.param * factor)
|
||||
return diag_chol
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
pass
|
||||
# TODO
|
||||
# TODO: It may be possible to improve speed and stability by doing the conversion from Pearson correlation + stds to the covariance directly in Cholesky form.
|
||||
stds = self.stds(x)
|
||||
smol = self._parameterize_full(self.params)
|
||||
big = self.padder(smol)
|
||||
pearson_cor_chol = big + th.eye(stds.shape[0])
|
||||
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
|
||||
cov = stds * pearson_cor * stds
|
||||
chol = th.linalg.cholesky(cov)
|
||||
return chol
|
||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||
# TODO
|
||||
pass
|
||||
factor = self.factor(x)
|
||||
return self._parameterize_full(self.params * factor[0])
|
||||
raise Exception()
|
||||
|
||||
@property
|
||||
@ -356,14 +364,11 @@ class CholNet(nn.Module):
|
||||
raise Exception()
|
||||
|
||||
def _chol_from_flat(self, flat_chol):
|
||||
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
||||
chol = fill_triangular(flat_chol)
|
||||
return self._ensure_diagonal_positive(chol)
|
||||
|
||||
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
||||
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
||||
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
||||
# self._flat_chol_len, -1, -1)
|
||||
sphe_chol = fill_triangular(pos_flat_sphe_chol)
|
||||
chol = self._chol_from_sphe_chol(sphe_chol)
|
||||
return chol
|
||||
@ -377,17 +382,26 @@ class CholNet(nn.Module):
|
||||
# S[i,j] in (0, pi) where i = 2..n, j = 2..i
|
||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||
# We ensure S < pi by applying tanh(S)*pi to all applicable elements
|
||||
#import pdb
|
||||
# pdb.set_trace()
|
||||
S = sphe_chol
|
||||
n = self.action_dim
|
||||
L = th.zeros_like(sphe_chol)
|
||||
for i in range(n):
|
||||
for j in range(i):
|
||||
t = S[i, 1]
|
||||
for k in range(1, j+1):
|
||||
t *= th.sin(th.tanh(S[i, k])*pi)
|
||||
t = 1
|
||||
s = ''
|
||||
for j in range(i+1):
|
||||
maybe_cos = 1
|
||||
s_maybe_cos = ''
|
||||
if i != j:
|
||||
t *= th.cos(th.tanh(S[i, j+1])*pi)
|
||||
L[i, j] = t
|
||||
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||
s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
L[i, j] = S[i, 0] * t * maybe_cos
|
||||
print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||
'=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
|
||||
if j <= i and j < n-1 and i < n:
|
||||
t *= th.sin(th.tanh(S[i, j+1])*pi)
|
||||
s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||
return L
|
||||
|
||||
def _ensure_positive_func(self, x):
|
||||
|
Loading…
Reference in New Issue
Block a user