From 046fa78206740b09cb1b2859d5f6629f2bd6d169 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 16 Jul 2022 17:34:25 +0200 Subject: [PATCH] Fixed: _chol_from_sphe_chol was unable to handle batches --- .../distributions/distributions.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 4f0373f..1fd530e 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -388,6 +388,7 @@ class CholNet(nn.Module): # S[i,j] e (0, pi) where i = 2..n, j = 2..i # We already ensure S > 0 in _chol_from_flat_sphe_chol # We ensure < pi by applying tanh*pi to all applicable elements + batch = (len(sphe_chol.shape) == 3) S = sphe_chol n = sphe_chol.shape[-1] L = th.zeros_like(sphe_chol) @@ -397,14 +398,23 @@ class CholNet(nn.Module): for j in range(i+1): maybe_cos = 1 #s_maybe_cos = '' - if i != j: - maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) - s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' - L[i, j] = S[i, 0] * t * maybe_cos + if i != j and j < n-1 and i < n: + if batch: + maybe_cos = th.cos(th.tanh(S[:, i, j+1])*pi) + else: + maybe_cos = th.cos(th.tanh(S[i, j+1])*pi) + #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')' + if batch: + L[:, i, j] = S[:, i, 0] * t * maybe_cos + else: + L[i, j] = S[i, 0] * t * maybe_cos # print('[L_'+str(i+1)+']_'+str(j+1) + # '=[l_'+str(i+1)+']_1'+s+s_maybe_cos) if j <= i and j < n-1 and i < n: - t *= th.sin(th.tanh(S[i, j+1])*pi) + if batch: + t *= th.sin(th.tanh(S[:, i, j+1])*pi) + else: + t *= th.sin(th.tanh(S[i, j+1])*pi) #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')' return L