diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index ccacbef..e9b4222 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -379,26 +379,24 @@ class CholNet(nn.Module): # S[i,j] e (0, pi) where i = 2..n, j = 2..i # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements - #import pdb - # pdb.set_trace() S = sphe_chol n = self.action_dim L = th.zeros_like(sphe_chol) for i in range(n): t = 1 - s = '' + #s = '' for j in range(i+1): maybe_cos = 1 - s_maybe_cos = '' + #s_maybe_cos = '' if i != j: maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' L[i, j] = S[i, 0] * t * maybe_cos - print('[L_'+str(i+1)+']_'+str(j+1) + - '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) + # print('[L_'+str(i+1)+']_'+str(j+1) + + # '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) if j <= i and j < n-1 and i < n: t *= th.sin(th.tanh(S[i, j+1])*pi) - s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')' + #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')' return L def _ensure_positive_func(self, x):