Hybrid Diag -> Full Implemented; Made spherical_chol more efficient
This commit is contained in:
parent
f184b88f19
commit
fa167b3e5f
@ -79,21 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
|||||||
for cs in allowedCovStrength:
|
for cs in allowedCovStrength:
|
||||||
if ps.value > cs.value:
|
if ps.value > cs.value:
|
||||||
continue
|
continue
|
||||||
if ps == Strength.SCALAR and cs == Strength.FULL:
|
|
||||||
# TODO: Maybe allow?
|
|
||||||
continue
|
|
||||||
if ps == Strength.DIAG and cs == Strength.FULL:
|
if ps == Strength.DIAG and cs == Strength.FULL:
|
||||||
# TODO: Implement
|
# TODO: Implement
|
||||||
continue
|
continue
|
||||||
if ps == Strength.NONE:
|
|
||||||
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
|
|
||||||
else:
|
|
||||||
for ept in allowedEPTs:
|
for ept in allowedEPTs:
|
||||||
if cs == Strength.FULL:
|
if cs == Strength.FULL:
|
||||||
for pt in allowedPTs:
|
for pt in allowedPTs:
|
||||||
|
if pt != ParametrizationType.NONE:
|
||||||
yield (ps, cs, ept, pt)
|
yield (ps, cs, ept, pt)
|
||||||
else:
|
else:
|
||||||
yield (ps, cs, ept, ProbSquashingType.NONE)
|
yield (ps, cs, ept, ParametrizationType.NONE)
|
||||||
|
|
||||||
|
|
||||||
def make_proba_distribution(
|
def make_proba_distribution(
|
||||||
@ -271,7 +266,8 @@ class CholNet(nn.Module):
|
|||||||
if self.cov_strength == Strength.NONE:
|
if self.cov_strength == Strength.NONE:
|
||||||
self.chol = th.ones(self.action_dim) * std_init
|
self.chol = th.ones(self.action_dim) * std_init
|
||||||
elif self.cov_strength == Strength.SCALAR:
|
elif self.cov_strength == Strength.SCALAR:
|
||||||
self.param = nn.Parameter(std_init, requires_grad=True)
|
self.param = nn.Parameter(
|
||||||
|
th.Tensor([std_init]), requires_grad=True)
|
||||||
elif self.cov_strength == Strength.DIAG:
|
elif self.cov_strength == Strength.DIAG:
|
||||||
self.params = nn.Parameter(
|
self.params = nn.Parameter(
|
||||||
th.ones(self.action_dim) * std_init, requires_grad=True)
|
th.ones(self.action_dim) * std_init, requires_grad=True)
|
||||||
@ -295,11 +291,16 @@ class CholNet(nn.Module):
|
|||||||
self.param = nn.Parameter(
|
self.param = nn.Parameter(
|
||||||
th.ones(self.action_dim), requires_grad=True)
|
th.ones(self.action_dim), requires_grad=True)
|
||||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
# TODO
|
self.stds = nn.Linear(latent_dim, self.action_dim)
|
||||||
pass
|
self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
|
||||||
|
# TODO: Init Non-zero?
|
||||||
|
self.params = nn.Parameter(
|
||||||
|
th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True)
|
||||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||||
# TODO
|
self.factor = nn.Linear(latent_dim, 1)
|
||||||
pass
|
# TODO: Init Off-axis differently?
|
||||||
|
self.params = nn.Parameter(
|
||||||
|
th.ones(self._full_params_len) * std_init, requires_grad=True)
|
||||||
else:
|
else:
|
||||||
raise Exception("This Exception can't happen")
|
raise Exception("This Exception can't happen")
|
||||||
|
|
||||||
@ -310,7 +311,7 @@ class CholNet(nn.Module):
|
|||||||
return self.chol
|
return self.chol
|
||||||
elif self.cov_strength == Strength.SCALAR:
|
elif self.cov_strength == Strength.SCALAR:
|
||||||
return self._ensure_positive_func(
|
return self._ensure_positive_func(
|
||||||
th.ones(self.action_dim) * self.param)
|
th.ones(self.action_dim) * self.param[0])
|
||||||
elif self.cov_strength == Strength.DIAG:
|
elif self.cov_strength == Strength.DIAG:
|
||||||
return self._ensure_positive_func(self.params)
|
return self._ensure_positive_func(self.params)
|
||||||
elif self.cov_strength == Strength.FULL:
|
elif self.cov_strength == Strength.FULL:
|
||||||
@ -328,16 +329,23 @@ class CholNet(nn.Module):
|
|||||||
return self._parameterize_full(params)
|
return self._parameterize_full(params)
|
||||||
else:
|
else:
|
||||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||||
factor = self.factor(x)
|
factor = self.factor(x)[0]
|
||||||
diag_chol = self._ensure_positive_func(
|
diag_chol = self._ensure_positive_func(
|
||||||
self.param * factor[0])
|
self.param * factor)
|
||||||
return diag_chol
|
return diag_chol
|
||||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
pass
|
# TODO: It may be possible to improve speed and stability by converting the Pearson correlation + stds to the covariance directly in Cholesky form.
|
||||||
# TODO
|
stds = self.stds(x)
|
||||||
|
smol = self._parameterize_full(self.params)
|
||||||
|
big = self.padder(smol)
|
||||||
|
pearson_cor_chol = big + th.eye(stds.shape[0])
|
||||||
|
pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
|
||||||
|
cov = stds * pearson_cor * stds
|
||||||
|
chol = th.linalg.cholesky(cov)
|
||||||
|
return chol
|
||||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||||
# TODO
|
factor = self.factor(x)
|
||||||
pass
|
return self._parameterize_full(self.params * factor[0])
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -356,14 +364,11 @@ class CholNet(nn.Module):
|
|||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
def _chol_from_flat(self, flat_chol):
|
def _chol_from_flat(self, flat_chol):
|
||||||
# chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
|
|
||||||
chol = fill_triangular(flat_chol)
|
chol = fill_triangular(flat_chol)
|
||||||
return self._ensure_diagonal_positive(chol)
|
return self._ensure_diagonal_positive(chol)
|
||||||
|
|
||||||
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
|
||||||
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
|
||||||
# sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
|
|
||||||
# self._flat_chol_len, -1, -1)
|
|
||||||
sphe_chol = fill_triangular(pos_flat_sphe_chol)
|
sphe_chol = fill_triangular(pos_flat_sphe_chol)
|
||||||
chol = self._chol_from_sphe_chol(sphe_chol)
|
chol = self._chol_from_sphe_chol(sphe_chol)
|
||||||
return chol
|
return chol
|
||||||
@ -377,17 +382,26 @@ class CholNet(nn.Module):
|
|||||||
# S[i,j] lies in the open interval (0, pi) where i = 2..n, j = 2..i
|
# S[i,j] lies in the open interval (0, pi) where i = 2..n, j = 2..i
|
||||||
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
# We already ensure S > 0 in _chol_from_flat_sphe_chol
|
||||||
# We ensure < pi by applying tanh*pi to all applicable elements
|
# We ensure < pi by applying tanh*pi to all applicable elements
|
||||||
|
#import pdb
|
||||||
|
# pdb.set_trace()
|
||||||
S = sphe_chol
|
S = sphe_chol
|
||||||
n = self.action_dim
|
n = self.action_dim
|
||||||
L = th.zeros_like(sphe_chol)
|
L = th.zeros_like(sphe_chol)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(i):
|
t = 1
|
||||||
t = S[i, 1]
|
s = ''
|
||||||
for k in range(1, j+1):
|
for j in range(i+1):
|
||||||
t *= th.sin(th.tanh(S[i, k])*pi)
|
maybe_cos = 1
|
||||||
|
s_maybe_cos = ''
|
||||||
if i != j:
|
if i != j:
|
||||||
t *= th.cos(th.tanh(S[i, j+1])*pi)
|
maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
|
||||||
L[i, j] = t
|
s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||||
|
L[i, j] = S[i, 0] * t * maybe_cos
|
||||||
|
print('[L_'+str(i+1)+']_'+str(j+1) +
|
||||||
|
'=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
|
||||||
|
if j <= i and j < n-1 and i < n:
|
||||||
|
t *= th.sin(th.tanh(S[i, j+1])*pi)
|
||||||
|
s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
|
||||||
return L
|
return L
|
||||||
|
|
||||||
def _ensure_positive_func(self, x):
|
def _ensure_positive_func(self, x):
|
||||||
|
Loading…
Reference in New Issue
Block a user