Implemented cov parametrization via eigen-decomp
parent e4a8cfc349
commit 0a037deccc
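The new EIGEN / EIGEN_RAW modes parametrize the covariance through its eigendecomposition: an orthogonal matrix Q built from Givens rotation angles times a diagonal of positive eigenvalues, cov = Q @ diag(eigenv) @ Q.T. A minimal standalone sketch of that construction in plain PyTorch (names and numbers here are illustrative only, not taken from the diff):

import math
import torch as th

theta = 0.3                        # single Givens angle in 2-D
eigenv = th.tensor([2.0, 0.5])     # positive eigenvalues
c, s = math.cos(theta), math.sin(theta)
Q = th.tensor([[c, -s], [s, c]])   # orthogonal rotation matrix
cov = Q @ th.diag(eigenv) @ Q.T    # Sigma = Q diag(lambda) Q^T
chol = th.linalg.cholesky(cov)     # lower-triangular factor the policy works with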
@@ -7,6 +7,8 @@ from torch import nn
 from torch.distributions import Normal, Independent, MultivariateNormal
 from math import pi
 
+import givens
+
 from stable_baselines3.common.preprocessing import get_action_dim
 
 from stable_baselines3.common.distributions import sum_independent_dims
@@ -34,9 +36,8 @@ class ParametrizationType(Enum):
     NONE = 0
     CHOL = 1
     SPHERICAL_CHOL = 2
-    # Not (yet?) implemented:
-    # GIVENS = 3
-    # NNLN_EIGEN = 4
+    EIGEN = 3
+    EIGEN_RAW = 4
 
 
 class EnforcePositiveType(Enum):
@@ -382,6 +383,9 @@ class CholNet(nn.Module):
 
         self._flat_chol_len = action_dim * (action_dim + 1) // 2
 
+        self._givens_rotator = givens.Rotation(action_dim)
+        self._givens_ident = th.eye(action_dim)
+
         # Yes, this is ugly.
         # But I don't know how this mess could be elegantly abstracted away...
 
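givens.Rotation (added below in metastable_baselines/misc/givens.py) stores one angle per coordinate pair, i.e. D * (D - 1) / 2 parameters, and yields the orthogonal factor Q when applied to the identity matrix kept in self._givens_ident. A quick standalone check of that parameter count (D chosen arbitrarily):

import itertools

D = 4                                                      # hypothetical action_dim
n_angles = len(list(itertools.combinations(range(D), 2)))
assert n_angles == D * (D - 1) // 2                        # 6 Givens angles for D = 4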
@@ -493,6 +497,10 @@ class CholNet(nn.Module):
             return self._flat_chol_len
         elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
             return self._flat_chol_len
+        elif self.par_type == ParametrizationType.EIGEN:
+            return self.action_dim * 2
+        elif self.par_type == ParametrizationType.EIGEN_RAW:
+            return self.action_dim * 2
         raise Exception()
 
     def _parameterize_full(self, params):
@@ -500,6 +508,10 @@ class CholNet(nn.Module):
             return self._chol_from_flat(params)
         elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
             return self._chol_from_flat_sphe_chol(params)
+        elif self.par_type == ParametrizationType.EIGEN:
+            return self._chol_from_givens_params(params, True)
+        elif self.par_type == ParametrizationType.EIGEN_RAW:
+            return self._chol_from_givens_params(params, False)
         raise Exception()
 
     def _chol_from_flat(self, flat_chol):
@@ -576,6 +588,24 @@ class CholNet(nn.Module):
         return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
                                                                         dim2=-1)).diag_embed() + chol.triu(1)
 
+    def _chol_from_givens_params(self, params, bijection=False):
+        theta, eigenv = params[:self.action_dim], params[self.action_dim:]
+
+        eigenv = self._ensure_positive_func(eigenv)
+
+        if bijection:
+            eigenv = th.cumsum(eigenv, -1)
+            # reverse order, oh well...
+
+        self._givens_rotator.theta = theta
+        Q = self._givens_rotator(self._givens_ident)
+        Qinv = Q.transpose(dim0=-2, dim1=-1)
+
+        cov = Q @ th.diag(eigenv) @ Qinv  # Sigma = Q diag(lambda) Q^T
+        chol = th.linalg.cholesky(cov)
+
+        return chol
+
     def string(self):
         return '<CholNet />'
 
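With bijection=True, _chol_from_givens_params runs the (already positive) raw values through a cumulative sum, so the eigenvalues come out sorted in ascending order and the permutation ambiguity of the eigendecomposition is removed. A tiny illustration with made-up numbers:

import torch as th

raw = th.tensor([0.5, 1.2, 0.1])      # positive increments
eigenv = th.cumsum(raw, -1)           # tensor([0.5, 1.7, 1.8]); ascending by construction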

metastable_baselines/misc/givens.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+# Source : https://github.com/diadochos/givens-torch
+# TODO: License
+
+import itertools
+import torch
+import torch.nn as nn
+
+
+def G_transpose(D, i, j, theta):
+    """Generate Givens rotation matrix.
+    >>> G_transpose(2, 0, 1, torch.FloatTensor([[3.1415 / 2]]))
+    tensor([[ 4.6329e-05,  1.0000e+00],
+            [-1.0000e+00,  4.6329e-05]])
+    """
+    R = torch.eye(D)
+    s, c = torch.sin(theta), torch.cos(theta)
+    R[i, i] = c
+    R[j, j] = c
+    R[i, j] = s
+    R[j, i] = -s
+    return R
+
+
+class Rotation(nn.Module):
+    def __init__(self, D):
+        """
+        >>> # Initialized as an identity.
+        >>> A, R = torch.eye(3), Rotation(3)
+        >>> torch.all(A.eq(R(A))).item()
+        True
+        """
+        super().__init__()
+        self.D = D
+        self.theta = torch.zeros(
+            (len(list(itertools.combinations(range(self.D), 2))), ))
+
+    def forward(self, x):
+        """Apply rotation.
+        >>> A, R = torch.eye(3), Rotation(3)
+        >>> R.theta = torch.FloatTensor([3.1415 / 2, 0., 0.])
+        >>> R(A)
+        tensor([[ 4.6329e-05,  1.0000e+00,  0.0000e+00],
+                [-1.0000e+00,  4.6329e-05,  0.0000e+00],
+                [ 0.0000e+00,  0.0000e+00,  1.0000e+00]])
+        """
+        for idx, (i, j) in enumerate(itertools.combinations(range(self.D), 2)):
+            x = torch.matmul(x, G_transpose(self.D, i, j, self.theta[idx]))
+        return x
+
+    def reverse(self, x):
+        """Apply reverse rotation.
+        >>> A, R = torch.eye(3), Rotation(3)
+        >>> R.weight = torch.FloatTensor([1., 2., 3.])
+        >>> torch.any(A.eq(R(A))).item()
+        True
+        >>> torch.all(A.eq(R.reverse(R(A)))).item()
+        True
+        """
+        for idx, (i, j) in reversed(list(enumerate(itertools.combinations(range(self.D), 2)))):
+            x = torch.matmul(x, G_transpose(self.D, i, j, -self.theta[idx]))
+        return x
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()