diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py
index df9e373..19a495b 100644
--- a/priorConditionedAnnealing/noise.py
+++ b/priorConditionedAnnealing/noise.py
@@ -46,6 +46,9 @@ class White_Noise():
             shape = self.known_shape
         return th.Tensor(np.random.normal(0, 1, shape))
 
+    def reset(self):
+        pass
+
 
 def get_colored_noise(beta, known_shape=None):
     if beta == 0:
diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py
index 3bdf504..3589987 100644
--- a/priorConditionedAnnealing/pca.py
+++ b/priorConditionedAnnealing/pca.py
@@ -5,19 +5,20 @@ import scipy.spatial
 from torch import nn
 from stable_baselines3.common.distributions import Distribution as SB3_Distribution
 from stable_baselines3.common.distributions import sum_independent_dims
-from torch.distributions import Normal
+from torch.distributions import Normal, MultivariateNormal
 import torch.nn.functional as F
-
 from priorConditionedAnnealing import noise, kernel
+from priorConditionedAnnealing.tensor_ops import fill_triangular, fill_triangular_inverse
 
 
 class Par_Strength(Enum):
     SCALAR = 'SCALAR'
     DIAG = 'DIAG'
+    FULL = 'FULL'
     CONT_SCALAR = 'CONT_SCALAR'
     CONT_DIAG = 'CONT_DIAG'
     CONT_HYBRID = 'CONT_HYBRID'
-
+    CONT_FULL = 'CONT_FULL'
 
 
 class EnforcePositiveType(Enum):
@@ -31,7 +32,7 @@ class EnforcePositiveType(Enum):
 
     def apply(self, x):
         # aaaaaa
-        return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
+        return [nn.Identity(), nn.Softplus(beta=10, threshold=2), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
 
 
 class Avaible_Kernel_Funcs(Enum):
@@ -98,6 +99,7 @@ class PCA_Distribution(SB3_Distribution):
         epsilon: float = 1e-6,
         skip_conditioning: bool = False,
         temporal_gradient_emission: bool = False,
+        msqrt_induces_full: bool = False,
         Base_Noise=noise.White_Noise,
     ):
         super().__init__()
@@ -111,9 +113,12 @@ class PCA_Distribution(SB3_Distribution):
         self.epsilon = epsilon
         self.skip_conditioning = skip_conditioning
         self.temporal_gradient_emission = temporal_gradient_emission
+        self.msqrt_induces_full = msqrt_induces_full
 
         self.base_noise = cast_to_Noise(Base_Noise, (n_envs, action_dim))
 
+        assert not (not skip_conditioning and self.is_full()), 'Conditioning of full covariances is not yet implemented'
+
         # Premature optimization is the root of all evil
         self._build_conditioner()
         # *Optimizes it anyways*
@@ -126,8 +131,13 @@ class PCA_Distribution(SB3_Distribution):
 
     def proba_distribution(
             self, mean_actions: th.Tensor, std_actions: th.Tensor) -> SB3_Distribution:
-        self.distribution = Normal(
-            mean_actions, std_actions)
+        if self.is_full():
+            self.distribution = MultivariateNormal(mean_actions, scale_tril=std_actions, validate_args=not self.msqrt_induces_full)
+            # self.distribution.scale = th.diagonal(std_actions, dim1=-2, dim2=-1)
+            self.distribution._mark_mSqrt = self.msqrt_induces_full
+        else:
+            self.distribution = Normal(
+                mean_actions, std_actions)
         return self
 
     def log_prob(self, actions: th.Tensor) -> th.Tensor:
@@ -156,24 +166,29 @@ class PCA_Distribution(SB3_Distribution):
             return self.mode()
         return self.sample(traj=trajectory)
 
-    def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
+    def sample(self, traj: th.Tensor, f_sigma: float = 1.0, epsilon=None) -> th.Tensor:
         assert self.skip_conditioning or type(traj) != type(None), 'A past trajectory has to be supplied if conditinoning is performed'
-        pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
-        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
+        pi_mean, pi_decomp = self.distribution.mean.cpu(), self.distribution.scale_tril.cpu() if self.is_full() else self.distribution.scale.cpu()
+        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_decomp)
         rho_std = rho_std * f_sigma
 
-        eta = self._get_rigged(pi_mean, pi_std,
+        eta = self._get_rigged(pi_mean, pi_decomp,
                                rho_mean, rho_std,
                                epsilon)
         # reparameterization with rigged samples
-        actions = pi_mean + pi_std * eta
+        if self.is_full():
+            actions = pi_mean + th.matmul(pi_decomp, eta.unsqueeze(-1)).squeeze(-1)
+        else:
+            actions = pi_mean + pi_decomp * eta
         self.gaussian_actions = actions
         return actions
 
     def is_contextual(self):
-        return True  # TODO: Remove, when bug for non-contextual is fixed
-        # Always returning True will merely waste cpu cycles
-        return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
+        return self.par_strength in [Par_Strength.CONT_SCALAR, Par_Strength.CONT_DIAG, Par_Strength.CONT_HYBRID, Par_Strength.CONT_FULL]
+
+    def is_full(self):
+        return self.par_strength in [Par_Strength.FULL, Par_Strength.CONT_FULL]
+
 
     def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         # Ugly function to ensure, that the gradients flow as intended for each modus operandi
@@ -251,7 +266,8 @@ class PCA_Distribution(SB3_Distribution):
         # S_{ij} = \frac{1}{D_j} \left( A_{ij} - \sum_{k=1}^{j-1} S_{ik} S_{jk} D_k \right), \qquad\text{for } i>j
         # https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
         # This way conditioning of the GP can be done in O(dim(A)) time.
-        if not self.is_contextual():
+        if not self.is_contextual() and False:
+            # Always assuming contextual will merely waste cpu cycles
             # TODO: fix, this does not work
             # safe inplace
             self.conditioner[-1, -
@@ -289,7 +305,7 @@ class PCA_Distribution(SB3_Distribution):
 
 
 class StdNet(nn.Module):
-    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std):
+    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float, return_log_std: bool):
         super().__init__()
         self.action_dim = action_dim
         self.latent_dim = latent_dim
@@ -299,8 +315,6 @@ class StdNet(nn.Module):
         self.epsilon = epsilon
         self.return_log_std = return_log_std
 
-        if return_log_std:
-            self.enforce_positive_type = EnforcePositiveType.NONE
 
         if self.par_strength == Par_Strength.SCALAR:
             self.param = nn.Parameter(
@@ -308,6 +322,11 @@ class StdNet(nn.Module):
         elif self.par_strength == Par_Strength.DIAG:
             self.param = nn.Parameter(
                 th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
+        elif self.par_strength == Par_Strength.FULL:
+            ident = th.eye(action_dim)*std_init
+            ident_chol = fill_triangular_inverse(ident)
+            self.param = nn.Parameter(
+                th.Tensor(ident_chol), requires_grad=True)
         elif self.par_strength == Par_Strength.CONT_SCALAR:
             self.net = nn.Linear(latent_dim, 1)
         elif self.par_strength == Par_Strength.CONT_HYBRID:
@@ -316,6 +335,11 @@ class StdNet(nn.Module):
                 th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
         elif self.par_strength == Par_Strength.CONT_DIAG:
             self.net = nn.Linear(latent_dim, self.action_dim)
+            self.bias = th.ones(action_dim)*self.std_init
+        elif self.par_strength == Par_Strength.CONT_FULL:
+            self.net = nn.Linear(latent_dim, action_dim * (action_dim + 1) // 2)
+            self.bias = fill_triangular_inverse(th.eye(action_dim)*self.std_init)
+
 
     def forward(self, x: th.Tensor) -> th.Tensor:
         if self.par_strength == Par_Strength.SCALAR:
@@ -323,6 +347,8 @@ class StdNet(nn.Module):
                 th.ones(self.action_dim) * self.param[0])
         elif self.par_strength == Par_Strength.DIAG:
             return self._ensure_positive_func(self.param)
+        elif self.par_strength == Par_Strength.FULL:
+            return self._chol_from_flat(self.param)
         elif self.par_strength == Par_Strength.CONT_SCALAR:
             cont = self.net(x)
             diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init
@@ -332,14 +358,27 @@ class StdNet(nn.Module):
             return self._ensure_positive_func(self.param * cont)
         elif self.par_strength == Par_Strength.CONT_DIAG:
             cont = self.net(x)
-            diag_chol = cont * self.std_init
+            diag_chol = cont + self.bias
             return self._ensure_positive_func(diag_chol)
+        elif self.par_strength == Par_Strength.CONT_FULL:
+            cont = self.net(x)
+            return self._chol_from_flat(cont + self.bias)
 
         raise Exception()
 
     def _ensure_positive_func(self, x):
         return self.enforce_positive_type.apply(x) + self.epsilon
 
+    def _chol_from_flat(self, flat_chol):
+        chol = fill_triangular(flat_chol)
+        return self._ensure_diagonal_positive(chol)
+
+    def _ensure_diagonal_positive(self, chol):
+        if len(chol.shape) == 1:
+            # If our chol is a vector (representing a diagonal chol)
+            return self._ensure_positive_func(chol)
+        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
+                                                          dim2=-1)).diag_embed() + chol.triu(1)
 
     def string(self):
         return ''
diff --git a/priorConditionedAnnealing/tensor_ops.py b/priorConditionedAnnealing/tensor_ops.py
new file mode 100644
index 0000000..0327246
--- /dev/null
+++ b/priorConditionedAnnealing/tensor_ops.py
@@ -0,0 +1,150 @@
+import torch as th
+import numpy as np
+
+
+def fill_triangular(x, upper=False):
+    """
+    The following function is derived from TensorFlow Probability
+    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L784
+    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
+    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
+    Creates a (batch of) triangular matrix from a vector of inputs.
+    Created matrix can be lower- or upper-triangular. (It is more efficient to
+    create the matrix as upper or lower, rather than transpose.)
+    Triangular matrix elements are filled in a clockwise spiral. See example,
+    below.
+    If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
+    `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
+    `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.
+    Example:
+    ```python
+    fill_triangular([1, 2, 3, 4, 5, 6])
+    # ==> [[4, 0, 0],
+    #      [6, 5, 0],
+    #      [3, 2, 1]]
+    fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
+    # ==> [[1, 2, 3],
+    #      [0, 5, 6],
+    #      [0, 0, 4]]
+    ```
+    The key trick is to create an upper triangular matrix by concatenating `x`
+    and a tail of itself, then reshaping.
+    Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
+    from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
+    contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
+    (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
+    the first (`n = 5`) elements removed and reversed:
+    ```python
+    x = np.arange(15) + 1
+    xc = np.concatenate([x, x[5:][::-1]])
+    # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
+    #            12, 11, 10, 9, 8, 7, 6])
+    # (We add one to the arange result to disambiguate the zeros below the
+    # diagonal of our upper-triangular matrix from the first entry in `x`.)
+    # Now, when we reshape this into a matrix:
+    y = np.reshape(xc, [5, 5])
+    # ==> array([[ 1,  2,  3,  4,  5],
+    #            [ 6,  7,  8,  9, 10],
+    #            [11, 12, 13, 14, 15],
+    #            [15, 14, 13, 12, 11],
+    #            [10,  9,  8,  7,  6]])
+    # Finally, zero the elements below the diagonal:
+    y = np.triu(y, k=0)
+    # ==> array([[ 1,  2,  3,  4,  5],
+    #            [ 0,  7,  8,  9, 10],
+    #            [ 0,  0, 13, 14, 15],
+    #            [ 0,  0,  0, 12, 11],
+    #            [ 0,  0,  0,  0,  6]])
+    ```
+    From this example we see that the resulting matrix is upper-triangular, and
+    contains all the entries of x, as desired. The rest is details:
+    - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
+      `n / 2` rows and half of an additional row), but the whole scheme still
+      works.
+    - If we want a lower triangular matrix instead of an upper triangular,
+      we remove the first `n` elements from `x` rather than from the reversed
+      `x`.
+    For additional comparisons, a pure numpy version of this function can be found
+    in `distribution_util_test.py`, function `_fill_triangular`.
+    Args:
+      x: `Tensor` representing lower (or upper) triangular elements.
+      upper: Python `bool` representing whether output matrix should be upper
+        triangular (`True`) or lower triangular (`False`, default).
+    Returns:
+      tril: `Tensor` with lower (or upper) triangular elements filled from `x`.
+    Raises:
+      ValueError: if `x` cannot be mapped to a triangular matrix.
+    """
+
+    m = np.int32(x.shape[-1])
+    # Formula derived by solving for n: m = n(n+1)/2.
+    n = np.sqrt(0.25 + 2. * m) - 0.5
+    if n != np.floor(n):
+        raise ValueError('Input right-most shape ({}) does not '
+                         'correspond to a triangular matrix.'.format(m))
+    n = np.int32(n)
+    new_shape = x.shape[:-1] + (n, n)
+
+    ndims = len(x.shape)
+    if upper:
+        x_list = [x, th.flip(x[..., n:], dims=[ndims - 1])]
+    else:
+        x_list = [x[..., n:], th.flip(x, dims=[ndims - 1])]
+
+    x = th.cat(x_list, dim=-1).reshape(new_shape)
+    x = th.triu(x) if upper else th.tril(x)
+    return x
+
+
+def fill_triangular_inverse(x, upper=False):
+    """
+    The following function is derived from TensorFlow Probability
+    https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/math/linalg.py#L934
+    Copyright (c) 2018 The TensorFlow Probability Authors, licensed under the Apache-2.0 license,
+    cf. 3rd-party-licenses.txt file in the root directory of this source tree.
+    Creates a vector from a (batch of) triangular matrix.
+    The vector is created from the lower-triangular or upper-triangular portion
+    depending on the value of the parameter `upper`.
+    If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
+    `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.
+    Example:
+    ```python
+    fill_triangular_inverse(
+        [[4, 0, 0],
+         [6, 5, 0],
+         [3, 2, 1]])
+    # ==> [1, 2, 3, 4, 5, 6]
+    fill_triangular_inverse(
+        [[1, 2, 3],
+         [0, 5, 6],
+         [0, 0, 4]], upper=True)
+    # ==> [1, 2, 3, 4, 5, 6]
+    ```
+    Args:
+      x: `Tensor` representing lower (or upper) triangular elements.
+      upper: Python `bool` representing whether output matrix should be upper
+        triangular (`True`) or lower triangular (`False`, default).
+    Returns:
+      flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
+        (or upper) triangular elements from `x`.
+    """
+
+    n = np.int32(x.shape[-1])
+    m = np.int32((n * (n + 1)) // 2)
+
+    ndims = len(x.shape)
+    if upper:
+        initial_elements = x[..., 0, :]
+        triangular_part = x[..., 1:, :]
+    else:
+        initial_elements = th.flip(x[..., -1, :], dims=[ndims - 2])
+        triangular_part = x[..., :-1, :]
+
+    rotated_triangular_portion = th.flip(
+        th.flip(triangular_part, dims=[ndims - 1]), dims=[ndims - 2])
+    consolidated_matrix = triangular_part + rotated_triangular_portion
+
+    end_sequence = consolidated_matrix.reshape(x.shape[:-2] + (n * (n - 1),))
+
+    y = th.cat([initial_elements, end_sequence[..., :m - n]], dim=-1)
+    return y
\ No newline at end of file
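As a reviewer's note, here is a minimal sketch (not part of the patch) of how the new `fill_triangular` / `fill_triangular_inverse` helpers and the diagonal-positivity trick from `StdNet._ensure_diagonal_positive` are intended to fit together for the FULL/CONT_FULL parameterization. The `action_dim`, `std_init`, and `epsilon` values are made up for illustration, and plain `F.softplus` is used as a stand-in for whatever `EnforcePositiveType` the policy is configured with.

```python
import torch as th
import torch.nn.functional as F

from priorConditionedAnnealing.tensor_ops import fill_triangular, fill_triangular_inverse

action_dim, std_init, epsilon = 3, 0.5, 1e-6  # hypothetical example values

# A FULL/CONT_FULL StdNet stores the Cholesky factor as a flat vector of
# length action_dim * (action_dim + 1) // 2 (here: the flat form of std_init * I) ...
flat = fill_triangular_inverse(th.eye(action_dim) * std_init)

# ... and fill_triangular maps that flat vector back to a lower-triangular matrix.
chol = fill_triangular(flat)
assert th.allclose(fill_triangular_inverse(chol), flat)  # round trip recovers the vector

# Mirror of _ensure_diagonal_positive: keep the strictly lower triangle and
# force the diagonal positive so the result is a valid scale_tril.
scale_tril = chol.tril(-1) + (F.softplus(chol.diagonal(dim1=-2, dim2=-1)) + epsilon).diag_embed()

# This is the kind of factor PCA_Distribution.proba_distribution consumes in the full case.
dist = th.distributions.MultivariateNormal(th.zeros(action_dim), scale_tril=scale_tril)
print(dist.sample().shape)  # torch.Size([3])
```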