From 9fffe048af7ebf403d97b52da4891b15d914da37 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 17 Aug 2022 22:55:42 +0200 Subject: [PATCH] Fixed Spherical_Chol not accepting batches --- .../distributions/distributions.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index b674f48..924197b 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -518,7 +518,6 @@ class CholNet(nn.Module): return chol def _chol_from_sphe_chol(self, sphe_chol): - # TODO: Test with batched data # TODO: Make efficient more # Note: # We must should ensure: @@ -527,16 +526,21 @@ class CholNet(nn.Module): # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements batch = (len(sphe_chol.shape) == 3) + batch_size = sphe_chol.shape[0] S = sphe_chol n = sphe_chol.shape[-1] L = th.zeros_like(sphe_chol) for i in range(n): #t = 1 t = th.Tensor([1])[0] + if batch: + t = t.expand((batch_size, 1)) #s = '' for j in range(i+1): #maybe_cos = 1 maybe_cos = th.Tensor([1])[0] + if batch: + maybe_cos = maybe_cos.expand((batch_size, 1)) #s_maybe_cos = '' if i != j and j < n-1 and i < n: if batch: @@ -545,14 +549,19 @@ class CholNet(nn.Module): maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' if batch: - L[:, i, j] = S[:, i, 0] * t * maybe_cos + # try: + L[:, i, j] = (S[:, i, 0] * t.T) * maybe_cos.T + # except: + # import pdb + # pdb.set_trace() else: L[i, j] = S[i, 0] * t * maybe_cos # print('[L_'+str(i+1)+']_'+str(j+1) + # '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) if j <= i and j < n-1 and i < n: if batch: - t *= th.sin(th.tanh(S[:, i, j+1])*pi) + tc = t.clone() + t = (tc.T * th.sin(th.tanh(S[:, i, j+1])*pi)).T else: t *= th.sin(th.tanh(S[i, j+1])*pi) #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'