# Source : https://github.com/diadochos/givens-torch
# TODO: License
import itertools

import torch
import torch.nn as nn


def G_transpose(D, i, j, theta, *, dtype=None, device=None):
    """Generate Givens rotation matrix.

    Returns a ``D x D`` matrix equal to the identity except in the ``(i, j)``
    plane, where ``R[i, i] = R[j, j] = cos(theta)``, ``R[i, j] = sin(theta)``
    and ``R[j, i] = -sin(theta)``.

    Args:
        D: Size of the square matrix.
        i, j: Indices spanning the rotation plane.
        theta: Tensor holding a single angle in radians (0-d, or a one-element
            tensor such as shape ``(1, 1)``); it is written into scalar slots.
        dtype: dtype of the returned matrix. ``None`` (default) keeps torch's
            default dtype, as before.
        device: Device of the returned matrix. ``None`` (default) keeps torch's
            default device, as before.

    Returns:
        The ``D x D`` rotation matrix.

    >>> G_transpose(2, 0, 1, torch.FloatTensor([[3.1415 / 2]]))
    tensor([[ 4.6329e-05,  1.0000e+00],
            [-1.0000e+00,  4.6329e-05]])
    """
    R = torch.eye(D, dtype=dtype, device=device)
    s, c = torch.sin(theta), torch.cos(theta)
    R[i, i] = c
    R[j, j] = c
    R[i, j] = s
    R[j, i] = -s
    return R


class Rotation(nn.Module):
    """Composition of Givens rotations, one angle per index pair ``i < j``.

    ``theta`` holds ``D * (D - 1) / 2`` angles, ordered like
    ``itertools.combinations(range(D), 2)``.
    """

    def __init__(self, D):
        """
        >>> # Initialized as an identity.
        >>> A, R = torch.eye(3), Rotation(3)
        >>> torch.all(A.eq(R(A))).item()
        True
        """
        super().__init__()
        self.D = D
        # Registered as a buffer (not a plain attribute) so .to()/.cuda()/
        # .double() carry it along with the module. persistent=False keeps
        # state_dict() unchanged, so existing checkpoints still load strictly.
        self.register_buffer(
            "theta", torch.zeros(len(self._pairs())), persistent=False)

    def _pairs(self):
        """Return the ``(i, j)`` index pairs (``i < j``) in ``theta`` order."""
        return list(itertools.combinations(range(self.D), 2))

    @staticmethod
    def _matrix_options(x):
        """Return ``G_transpose`` dtype/device kwargs matching the input ``x``."""
        # Non-floating inputs keep torch's default dtype, matching the old
        # behaviour instead of silently truncating sin/cos to integers.
        dtype = x.dtype if x.is_floating_point() else None
        return {"dtype": dtype, "device": x.device}

    def forward(self, x):
        """Apply rotation.

        Right-multiplies ``x`` (last dimension of size ``D``; leading
        dimensions broadcast as in ``torch.matmul``) by one rotation per index
        pair, in ``theta`` order.

        >>> A, R = torch.eye(3), Rotation(3)
        >>> R.theta = torch.FloatTensor([3.1415 / 2, 0., 0.])
        >>> R(A)
        tensor([[ 4.6329e-05,  1.0000e+00,  0.0000e+00],
                [-1.0000e+00,  4.6329e-05,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  1.0000e+00]])
        """
        opts = self._matrix_options(x)
        for idx, (i, j) in enumerate(self._pairs()):
            x = torch.matmul(
                x, G_transpose(self.D, i, j, self.theta[idx], **opts))
        return x

    def reverse(self, x):
        """Apply reverse rotation.

        Undoes ``forward``: applies the negated angles in reverse pair order.

        >>> A, R = torch.eye(3), Rotation(3)
        >>> R.theta = torch.FloatTensor([1., 2., 3.])
        >>> torch.allclose(A, R(A))
        False
        >>> torch.allclose(A, R.reverse(R(A)), atol=1e-5)
        True
        """
        opts = self._matrix_options(x)
        for idx, (i, j) in reversed(list(enumerate(self._pairs()))):
            x = torch.matmul(
                x, G_transpose(self.D, i, j, -self.theta[idx], **opts))
        return x