Fixed: _chol_from_sphe_chol was unable to handle batches
This commit is contained in:
		
							parent
							
								
									c141599662
								
							
						
					
					
						commit
						046fa78206
					
				@ -388,6 +388,7 @@ class CholNet(nn.Module):
 | 
				
			|||||||
        # S[i,j] e (0, pi)   where i = 2..n, j = 2..i
 | 
					        # S[i,j] e (0, pi)   where i = 2..n, j = 2..i
 | 
				
			||||||
        # We already ensure S > 0 in _chol_from_flat_sphe_chol
 | 
					        # We already ensure S > 0 in _chol_from_flat_sphe_chol
 | 
				
			||||||
        # We ensure < pi by applying tanh*pi to all applicable elements
 | 
					        # We ensure < pi by applying tanh*pi to all applicable elements
 | 
				
			||||||
 | 
					        batch = (len(sphe_chol.shape) == 3)
 | 
				
			||||||
        S = sphe_chol
 | 
					        S = sphe_chol
 | 
				
			||||||
        n = sphe_chol.shape[-1]
 | 
					        n = sphe_chol.shape[-1]
 | 
				
			||||||
        L = th.zeros_like(sphe_chol)
 | 
					        L = th.zeros_like(sphe_chol)
 | 
				
			||||||
@ -397,13 +398,22 @@ class CholNet(nn.Module):
 | 
				
			|||||||
            for j in range(i+1):
 | 
					            for j in range(i+1):
 | 
				
			||||||
                maybe_cos = 1
 | 
					                maybe_cos = 1
 | 
				
			||||||
                #s_maybe_cos = ''
 | 
					                #s_maybe_cos = ''
 | 
				
			||||||
                if i != j:
 | 
					                if i != j and j < n-1 and i < n:
 | 
				
			||||||
 | 
					                    if batch:
 | 
				
			||||||
 | 
					                        maybe_cos = th.cos(th.tanh(S[:, i, j+1])*pi)
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
                        maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
 | 
					                        maybe_cos = th.cos(th.tanh(S[i, j+1])*pi)
 | 
				
			||||||
                    s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
 | 
					                    #s_maybe_cos = 'cos([l_'+str(i+1)+']_'+str(j+2)+')'
 | 
				
			||||||
 | 
					                if batch:
 | 
				
			||||||
 | 
					                    L[:, i, j] = S[:, i, 0] * t * maybe_cos
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
                    L[i, j] = S[i, 0] * t * maybe_cos
 | 
					                    L[i, j] = S[i, 0] * t * maybe_cos
 | 
				
			||||||
                # print('[L_'+str(i+1)+']_'+str(j+1) +
 | 
					                # print('[L_'+str(i+1)+']_'+str(j+1) +
 | 
				
			||||||
                #      '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
 | 
					                #      '=[l_'+str(i+1)+']_1'+s+s_maybe_cos)
 | 
				
			||||||
                if j <= i and j < n-1 and i < n:
 | 
					                if j <= i and j < n-1 and i < n:
 | 
				
			||||||
 | 
					                    if batch:
 | 
				
			||||||
 | 
					                        t *= th.sin(th.tanh(S[:, i, j+1])*pi)
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
                        t *= th.sin(th.tanh(S[i, j+1])*pi)
 | 
					                        t *= th.sin(th.tanh(S[i, j+1])*pi)
 | 
				
			||||||
                    #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
 | 
					                    #s += 'sin([l_'+str(i+1)+']_'+str(j+2)+')'
 | 
				
			||||||
        return L
 | 
					        return L
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user