Implemented cov parametrization via eigen-decomp
This commit is contained in:
		
							parent
							
								
									e4a8cfc349
								
							
						
					
					
						commit
						0a037deccc
					
				| @ -7,6 +7,8 @@ from torch import nn | ||||
| from torch.distributions import Normal, Independent, MultivariateNormal | ||||
| from math import pi | ||||
| 
 | ||||
| import givens | ||||
| 
 | ||||
| from stable_baselines3.common.preprocessing import get_action_dim | ||||
| 
 | ||||
| from stable_baselines3.common.distributions import sum_independent_dims | ||||
| @ -34,9 +36,8 @@ class ParametrizationType(Enum): | ||||
|     NONE = 0 | ||||
|     CHOL = 1 | ||||
|     SPHERICAL_CHOL = 2 | ||||
|     # Not (yet?) implemented: | ||||
|     # GIVENS = 3 | ||||
|     # NNLN_EIGEN = 4 | ||||
|     EIGEN = 3 | ||||
|     EIGEN_RAW = 4 | ||||
| 
 | ||||
| 
 | ||||
| class EnforcePositiveType(Enum): | ||||
| @ -382,6 +383,9 @@ class CholNet(nn.Module): | ||||
| 
 | ||||
|         self._flat_chol_len = action_dim * (action_dim + 1) // 2 | ||||
| 
 | ||||
|         self._givens_rotator = givens.Rotation(action_dim) | ||||
|         self._givens_ident = th.eye(action_dim) | ||||
| 
 | ||||
|         # Yes, this is ugly. | ||||
|         # But I don't know how this mess could be elegantly abstracted away... | ||||
| 
 | ||||
| @ -493,6 +497,10 @@ class CholNet(nn.Module): | ||||
|             return self._flat_chol_len | ||||
|         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: | ||||
|             return self._flat_chol_len | ||||
|         elif self.par_type == ParametrizationType.EIGEN: | ||||
|             return self.action_dim * 2 | ||||
|         elif self.par_type == ParametrizationType.EIGEN_BIJECT: | ||||
|             return self.action_dim * 2 | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def _parameterize_full(self, params): | ||||
| @ -500,6 +508,10 @@ class CholNet(nn.Module): | ||||
|             return self._chol_from_flat(params) | ||||
|         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: | ||||
|             return self._chol_from_flat_sphe_chol(params) | ||||
|         elif self.par_type == ParametrizationType.EIGEN: | ||||
|             return self._chol_from_givens_params(params, True) | ||||
|         elif self.par_type == ParametrizationType.EIGEN_RAW: | ||||
|             return self._chol_from_givens_params(params, False) | ||||
|         raise Exception() | ||||
| 
 | ||||
|     def _chol_from_flat(self, flat_chol): | ||||
| @ -576,6 +588,24 @@ class CholNet(nn.Module): | ||||
|         return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, | ||||
|                                                                         dim2=-1)).diag_embed() + chol.triu(1) | ||||
| 
 | ||||
|     def _chol_from_givens_params(self, params, bijection=False): | ||||
|         theta, eigenv = params[:self.action_dim], params[self.action_dim:] | ||||
| 
 | ||||
|         eigenv = self._ensure_positive_func(eigenv) | ||||
| 
 | ||||
|         if bijection: | ||||
|             eigenv = th.cumsum(eigenv, -1) | ||||
|             # reverse order, oh well... | ||||
| 
 | ||||
|         self._givens_rot.theta = theta | ||||
|         Q = self._givens_rotator(self._givens_ident) | ||||
|         Qinv = Q.transpose(dim0=-2, dim1=-1) | ||||
| 
 | ||||
|         cov = Q * th.diag(eigenv) * Qinv | ||||
|         chol = th.linalg.cholesky(cov) | ||||
| 
 | ||||
|         return chol | ||||
| 
 | ||||
|     def string(self): | ||||
|         return '<CholNet />' | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										66
									
								
								metastable_baselines/misc/givens.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								metastable_baselines/misc/givens.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | ||||
| # Source : https://github.com/diadochos/givens-torch | ||||
| # TODO: License | ||||
| 
 | ||||
| import itertools | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| 
 | ||||
| 
 | ||||
def G_transpose(D, i, j, theta):
    """Build a ``D x D`` Givens rotation matrix acting on the ``(i, j)`` plane.

    Starts from the identity and overwrites four entries: ``[i, i]`` and
    ``[j, j]`` with ``cos(theta)``, ``[i, j]`` with ``sin(theta)`` and
    ``[j, i]`` with ``-sin(theta)``.

    >>> G_transpose(2, 0, 1, torch.FloatTensor([[3.1415 / 2]]))
    tensor([[ 4.6329e-05,  1.0000e+00],
            [-1.0000e+00,  4.6329e-05]])
    """
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    rot = torch.eye(D)
    # Write order is diagonal first, then off-diagonal; keep it so the same
    # entry wins if ``i == j``.
    entries = (
        ((i, i), cos_t),
        ((j, j), cos_t),
        ((i, j), sin_t),
        ((j, i), -sin_t),
    )
    for (row, col), value in entries:
        rot[row, col] = value
    return rot
| 
 | ||||
| 
 | ||||
class Rotation(nn.Module):
    """Rotation assembled as a product of Givens rotations.

    One angle is kept per unordered index pair ``(i, j)`` with ``i < j`` taken
    from ``range(D)``, in ``itertools.combinations`` order. ``theta`` is stored
    as a plain tensor attribute (it is not registered as a parameter or
    buffer), so callers can overwrite it directly before calling the module.
    """

    def __init__(self, D):
        """Create a rotation over ``D`` dimensions with every angle at zero.

        >>> # Initialized as an identity.
        >>> A, R = torch.eye(3), Rotation(3)
        >>> torch.all(A.eq(R(A))).item()
        True
        """
        super().__init__()
        self.D = D
        # One angle per unordered pair of axes: D choose 2.
        self.theta = torch.zeros((self.D * (self.D - 1) // 2, ))

    def _pairs(self):
        """Return ``(idx, (i, j))`` for every rotation plane, in forward order."""
        return list(enumerate(itertools.combinations(range(self.D), 2)))

    def forward(self, x):
        """Apply rotation.

        Right-multiplies ``x`` by the Givens matrix of each plane in turn.

        >>> A, R = torch.eye(3), Rotation(3)
        >>> R.theta = torch.FloatTensor([3.1415 / 2, 0., 0.])
        >>> R(A)
        tensor([[ 4.6329e-05,  1.0000e+00,  0.0000e+00],
                [-1.0000e+00,  4.6329e-05,  0.0000e+00],
                [ 0.0000e+00,  0.0000e+00,  1.0000e+00]])
        """
        for idx, (i, j) in self._pairs():
            x = torch.matmul(x, G_transpose(self.D, i, j, self.theta[idx]))
        return x

    def reverse(self, x):
        """Apply reverse rotation.

        Undoes ``forward`` by applying the planes in the opposite order with
        negated angles.

        >>> A, R = torch.eye(3), Rotation(3)
        >>> R.theta = torch.FloatTensor([1., 2., 3.])
        >>> torch.allclose(A, R(A))
        False
        >>> torch.allclose(A, R.reverse(R(A)), atol=1e-5)
        True
        """
        for idx, (i, j) in reversed(self._pairs()):
            x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
        return x
| 
 | ||||
| 
 | ||||
# Running this module as a script executes the doctests embedded above.
if __name__ == '__main__':
    from doctest import testmod
    testmod()
		Loading…
	
		Reference in New Issue
	
	Block a user