From fa167b3e5fb12debfb2e133c43ed3908ff2a6390 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 16 Jul 2022 13:05:35 +0200 Subject: [PATCH] Hybrid Diag -> Full Implemented; Made spherical_chol more efficient --- .../distributions/distributions.py | 78 +++++++++++-------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index ed282d6..c43423b 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -79,21 +79,16 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng for cs in allowedCovStrength: if ps.value > cs.value: continue - if ps == Strength.SCALAR and cs == Strength.FULL: - # TODO: Maybe allow? - continue if ps == Strength.DIAG and cs == Strength.FULL: # TODO: Implement continue - if ps == Strength.NONE: - yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE) - else: - for ept in allowedEPTs: - if cs == Strength.FULL: - for pt in allowedPTs: + for ept in allowedEPTs: + if cs == Strength.FULL: + for pt in allowedPTs: + if pt != ParametrizationType.NONE: yield (ps, cs, ept, pt) - else: - yield (ps, cs, ept, ProbSquashingType.NONE) + else: + yield (ps, cs, ept, ParametrizationType.NONE) def make_proba_distribution( @@ -271,7 +266,8 @@ class CholNet(nn.Module): if self.cov_strength == Strength.NONE: self.chol = th.ones(self.action_dim) * std_init elif self.cov_strength == Strength.SCALAR: - self.param = nn.Parameter(std_init, requires_grad=True) + self.param = nn.Parameter( + th.Tensor([std_init]), requires_grad=True) elif self.cov_strength == Strength.DIAG: self.params = nn.Parameter( th.ones(self.action_dim) * std_init, requires_grad=True) @@ -295,11 +291,16 @@ class CholNet(nn.Module): self.param = nn.Parameter( th.ones(self.action_dim), requires_grad=True) elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: - # TODO - pass + self.stds = nn.Linear(latent_dim, self.action_dim) + self.padder = th.nn.ZeroPad2d((0, 1, 1, 0)) + # TODO: Init Non-zero? + self.params = nn.Parameter( + th.ones(self._full_params_len - self.action_dim) * 0, requires_grad=True) elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: - # TODO - pass + self.factor = nn.Linear(latent_dim, 1) + # TODO: Init Off-axis differently? + self.params = nn.Parameter( + th.ones(self._full_params_len) * std_init, requires_grad=True) else: raise Exception("This Exception can't happen") @@ -310,7 +311,7 @@ class CholNet(nn.Module): return self.chol elif self.cov_strength == Strength.SCALAR: return self._ensure_positive_func( - th.ones(self.action_dim) * self.param) + th.ones(self.action_dim) * self.param[0]) elif self.cov_strength == Strength.DIAG: return self._ensure_positive_func(self.params) elif self.cov_strength == Strength.FULL: @@ -328,16 +329,23 @@ class CholNet(nn.Module): return self._parameterize_full(params) else: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: - factor = self.factor(x) + factor = self.factor(x)[0] diag_chol = self._ensure_positive_func( - self.param * factor[0]) + self.param * factor) return diag_chol elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: - pass - # TODO + # TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form. + stds = self.stds(x) + smol = self._parameterize_full(self.params) + big = self.padder(smol) + pearson_cor_chol = big + th.eye(stds.shape[0]) + pearson_cor = pearson_cor_chol.T @ pearson_cor_chol + cov = stds * pearson_cor * stds + chol = th.linalg.cholesky(cov) + return chol elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: - # TODO - pass + factor = self.factor(x) + return self._parameterize_full(self.params * factor[0]) raise Exception() @property @@ -356,14 +364,11 @@ class CholNet(nn.Module): raise Exception() def _chol_from_flat(self, flat_chol): - # chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1) chol = fill_triangular(flat_chol) return self._ensure_diagonal_positive(chol) def _chol_from_flat_sphe_chol(self, flat_sphe_chol): pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol) - # sphe_chol = fill_triangular(pos_flat_sphe_chol).expand( - # self._flat_chol_len, -1, -1) sphe_chol = fill_triangular(pos_flat_sphe_chol) chol = self._chol_from_sphe_chol(sphe_chol) return chol @@ -377,17 +382,26 @@ class CholNet(nn.Module): # S[i,j] e (0, pi) where i = 2..n, j = 2..i # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements + #import pdb + # pdb.set_trace() S = sphe_chol n = self.action_dim L = th.zeros_like(sphe_chol) for i in range(n): - for j in range(i): - t = S[i, 1] - for k in range(1, j+1): - t *= th.sin(th.tanh(S[i, k])*pi) + t = 1 + s = '' + for j in range(i+1): + maybe_cos = 1 + s_maybe_cos = '' if i != j: - t *= th.cos(th.tanh(S[i, j+1])*pi) - L[i, j] = t + maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) + s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' + L[i, j] = S[i, 0] * t * maybe_cos + print('[L_'+str(i+1)+']_'+str(j+1) + + '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) + if j <= i and j < n-1 and i < n: + t *= th.sin(th.tanh(S[i, j+1])*pi) + s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')' return L def _ensure_positive_func(self, x):