Implemented cov parametrization via eigen-decomp
This commit is contained in:
		
							parent
							
								
									e4a8cfc349
								
							
						
					
					
						commit
						0a037deccc
					
				| @ -7,6 +7,8 @@ from torch import nn | |||||||
| from torch.distributions import Normal, Independent, MultivariateNormal | from torch.distributions import Normal, Independent, MultivariateNormal | ||||||
| from math import pi | from math import pi | ||||||
| 
 | 
 | ||||||
|  | import givens | ||||||
|  | 
 | ||||||
| from stable_baselines3.common.preprocessing import get_action_dim | from stable_baselines3.common.preprocessing import get_action_dim | ||||||
| 
 | 
 | ||||||
| from stable_baselines3.common.distributions import sum_independent_dims | from stable_baselines3.common.distributions import sum_independent_dims | ||||||
| @ -34,9 +36,8 @@ class ParametrizationType(Enum): | |||||||
|     NONE = 0 |     NONE = 0 | ||||||
|     CHOL = 1 |     CHOL = 1 | ||||||
|     SPHERICAL_CHOL = 2 |     SPHERICAL_CHOL = 2 | ||||||
|     # Not (yet?) implemented: |     EIGEN = 3 | ||||||
|     # GIVENS = 3 |     EIGEN_RAW = 4 | ||||||
|     # NNLN_EIGEN = 4 |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class EnforcePositiveType(Enum): | class EnforcePositiveType(Enum): | ||||||
| @ -382,6 +383,9 @@ class CholNet(nn.Module): | |||||||
| 
 | 
 | ||||||
|         self._flat_chol_len = action_dim * (action_dim + 1) // 2 |         self._flat_chol_len = action_dim * (action_dim + 1) // 2 | ||||||
| 
 | 
 | ||||||
|  |         self._givens_rotator = givens.Rotation(action_dim) | ||||||
|  |         self._givens_ident = th.eye(action_dim) | ||||||
|  | 
 | ||||||
|         # Yes, this is ugly. |         # Yes, this is ugly. | ||||||
|         # But I don't know how this mess could be elegantly abstracted away... |         # But I don't know how this mess could be elegantly abstracted away... | ||||||
| 
 | 
 | ||||||
| @ -493,6 +497,10 @@ class CholNet(nn.Module): | |||||||
|             return self._flat_chol_len |             return self._flat_chol_len | ||||||
|         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: |         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: | ||||||
|             return self._flat_chol_len |             return self._flat_chol_len | ||||||
|  |         elif self.par_type == ParametrizationType.EIGEN: | ||||||
|  |             return self.action_dim * 2 | ||||||
|  |         elif self.par_type == ParametrizationType.EIGEN_BIJECT: | ||||||
|  |             return self.action_dim * 2 | ||||||
|         raise Exception() |         raise Exception() | ||||||
| 
 | 
 | ||||||
|     def _parameterize_full(self, params): |     def _parameterize_full(self, params): | ||||||
| @ -500,6 +508,10 @@ class CholNet(nn.Module): | |||||||
|             return self._chol_from_flat(params) |             return self._chol_from_flat(params) | ||||||
|         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: |         elif self.par_type == ParametrizationType.SPHERICAL_CHOL: | ||||||
|             return self._chol_from_flat_sphe_chol(params) |             return self._chol_from_flat_sphe_chol(params) | ||||||
|  |         elif self.par_type == ParametrizationType.EIGEN: | ||||||
|  |             return self._chol_from_givens_params(params, True) | ||||||
|  |         elif self.par_type == ParametrizationType.EIGEN_RAW: | ||||||
|  |             return self._chol_from_givens_params(params, False) | ||||||
|         raise Exception() |         raise Exception() | ||||||
| 
 | 
 | ||||||
|     def _chol_from_flat(self, flat_chol): |     def _chol_from_flat(self, flat_chol): | ||||||
| @ -576,6 +588,24 @@ class CholNet(nn.Module): | |||||||
|         return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, |         return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, | ||||||
|                                                                         dim2=-1)).diag_embed() + chol.triu(1) |                                                                         dim2=-1)).diag_embed() + chol.triu(1) | ||||||
| 
 | 
 | ||||||
|  |     def _chol_from_givens_params(self, params, bijection=False): | ||||||
|  |         theta, eigenv = params[:self.action_dim], params[self.action_dim:] | ||||||
|  | 
 | ||||||
|  |         eigenv = self._ensure_positive_func(eigenv) | ||||||
|  | 
 | ||||||
|  |         if bijection: | ||||||
|  |             eigenv = th.cumsum(eigenv, -1) | ||||||
|  |             # reverse order, oh well... | ||||||
|  | 
 | ||||||
|  |         self._givens_rot.theta = theta | ||||||
|  |         Q = self._givens_rotator(self._givens_ident) | ||||||
|  |         Qinv = Q.transpose(dim0=-2, dim1=-1) | ||||||
|  | 
 | ||||||
|  |         cov = Q * th.diag(eigenv) * Qinv | ||||||
|  |         chol = th.linalg.cholesky(cov) | ||||||
|  | 
 | ||||||
|  |         return chol | ||||||
|  | 
 | ||||||
|     def string(self): |     def string(self): | ||||||
|         return '<CholNet />' |         return '<CholNet />' | ||||||
| 
 | 
 | ||||||
|  | |||||||
							
								
								
									
										66
									
								
								metastable_baselines/misc/givens.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								metastable_baselines/misc/givens.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | |||||||
|  | # Source : https://github.com/diadochos/givens-torch | ||||||
|  | # TODO: License | ||||||
|  | 
 | ||||||
|  | import itertools | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def G_transpose(D, i, j, theta): | ||||||
|  |     """Generate Givens rotation matrix. | ||||||
|  |     >>> G_transpose(2, 0, 1, torch.FloatTensor([[3.1415 / 2]])) | ||||||
|  |     tensor([[ 4.6329e-05,  1.0000e+00], | ||||||
|  |             [-1.0000e+00,  4.6329e-05]]) | ||||||
|  |     """ | ||||||
|  |     R = torch.eye(D) | ||||||
|  |     s, c = torch.sin(theta), torch.cos(theta) | ||||||
|  |     R[i, i] = c | ||||||
|  |     R[j, j] = c | ||||||
|  |     R[i, j] = s | ||||||
|  |     R[j, i] = -s | ||||||
|  |     return R | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Rotation(nn.Module): | ||||||
|  |     def __init__(self, D): | ||||||
|  |         """ | ||||||
|  |         >>> # Initialized as an identity. | ||||||
|  |         >>> A, R = torch.eye(3), Rotation(3) | ||||||
|  |         >>> torch.all(A.eq(R(A))).item() | ||||||
|  |         True | ||||||
|  |         """ | ||||||
|  |         super().__init__() | ||||||
|  |         self.D = D | ||||||
|  |         self.theta = torch.zeros( | ||||||
|  |             (len(list(itertools.combinations(range(self.D), 2))), )) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         """Apply rotation. | ||||||
|  |         >>> A, R = torch.eye(3), Rotation(3) | ||||||
|  |         >>> R.theta = torch.FloatTensor([3.1415 / 2, 0., 0.]) | ||||||
|  |         >>> R(A) | ||||||
|  |         tensor([[ 4.6329e-05,  1.0000e+00,  0.0000e+00], | ||||||
|  |                 [-1.0000e+00,  4.6329e-05,  0.0000e+00], | ||||||
|  |                 [ 0.0000e+00,  0.0000e+00,  1.0000e+00]]) | ||||||
|  |         """ | ||||||
|  |         for idx, (i, j) in enumerate(itertools.combinations(range(self.D), 2)): | ||||||
|  |             x = torch.matmul(x, G_transpose(self.D, i, j, self.theta[idx])) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def reverse(self, x): | ||||||
|  |         """Apply reverse rotation. | ||||||
|  |         >>> A, R = torch.eye(3), Rotation(3) | ||||||
|  |         >>> R.weight = torch.FloatTensor([1., 2., 3.]) | ||||||
|  |         >>> torch.any(A.eq(R(A))).item() | ||||||
|  |         True | ||||||
|  |         >>> torch.all(A.eq(R.reverse(R(A)))).item() | ||||||
|  |         True | ||||||
|  |         """ | ||||||
|  |         for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))): | ||||||
|  |             x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx])) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     import doctest | ||||||
|  |     doctest.testmod() | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user