From 3304fd49f6ee4d19ca54ae05daeda538f4f50e34 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 13 Jul 2022 19:38:57 +0200 Subject: [PATCH] Working on UniversalGaussianDistribution --- .../distributions/distributions.py | 344 +++++++++++------- 1 file changed, 220 insertions(+), 124 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 1efd902..bace3df 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from enum import Enum +import gym import torch as th from torch import nn from torch.distributions import Normal, MultivariateNormal @@ -10,10 +11,14 @@ from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.distributions import sum_independent_dims from stable_baselines3.common.distributions import Distribution as SB3_Distribution +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + MultiCategoricalDistribution, + # StateDependentNoiseDistribution, +) from stable_baselines3.common.distributions import DiagGaussianDistribution -from ..misc.fakeModule import FakeModule -from ..misc.distTools import new_dist_like from ..misc.tensor_ops import fill_triangular # TODO: Integrate and Test what I currently have before adding more complexity @@ -34,7 +39,9 @@ class Strength(Enum): class ParametrizationType(Enum): CHOL = 1 SPHERICAL_CHOL = 2 + # Not (yet?) implemented: #GIVENS = 3 + #NNLN_EIGEN = 4 class EnforcePositiveType(Enum): @@ -45,7 +52,7 @@ class EnforcePositiveType(Enum): LOG = (4, th.log) def __init__(self, value, func): - self.value = value + self.val = value self._func = func def apply(self, x): @@ -57,7 +64,7 @@ class ProbSquashingType(Enum): TANH = (1, th.tanh) def __init__(self, value, func): - self.value = value + self.val = value self._func = func def apply(self, x): @@ -92,6 +99,38 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng yield (ps, cs, ept, None) +def make_proba_distribution( + action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None +) -> SB3_Distribution: + """ + Return an instance of Distribution for the correct type of action space + :param action_space: the input action space + :param use_sde: Force the use of StateDependentNoiseDistribution + instead of DiagGaussianDistribution + :param dist_kwargs: Keyword arguments to pass to the probability distribution + :return: the appropriate Distribution object + """ + if dist_kwargs is None: + dist_kwargs = {} + + if isinstance(action_space, gym.spaces.Box): + assert len( + action_space.shape) == 1, "Error: the action space must be a vector" + return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs) + elif isinstance(action_space, gym.spaces.Discrete): + return CategoricalDistribution(action_space.n, **dist_kwargs) + elif isinstance(action_space, gym.spaces.MultiDiscrete): + return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs) + elif isinstance(action_space, gym.spaces.MultiBinary): + return BernoulliDistribution(action_space.n, **dist_kwargs) + else: + raise NotImplementedError( + "Error: probability distribution, not implemented for action space" + f"of type {type(action_space)}." + " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary." + ) + + class UniversalGaussianDistribution(SB3_Distribution): """ Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions. @@ -99,8 +138,9 @@ class UniversalGaussianDistribution(SB3_Distribution): :param action_dim: Dimension of the action space. """ - def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=Strength.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH): + def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE): super(UniversalGaussianDistribution, self).__init__() + self.action_dim = action_dim self.par_strength = neural_strength self.cov_strength = cov_strength self.par_type = parameterization_type @@ -109,18 +149,27 @@ class UniversalGaussianDistribution(SB3_Distribution): self.distribution = None - self._flat_chol_len = action_dim * (action_dim + 1) // 2 + if self.prob_squashing_type != ProbSquashingType.NONE: + raise Exception('ProbSquasing is not yet implmenented!') - def new_dist_like_me(self, mean, pseudo_chol): + if use_sde: + raise Exception('SDE is not yet implemented') + + def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor): p = self.distribution - np = new_dist_like(p, mean, pseudo_chol) + if isinstance(p, th.distributions.Normal): + if p.stddev.shape != chol.shape: + chol = th.diagonal(chol, dim1=1, dim2=2) + np = th.distributions.Normal(mean, chol) + elif isinstance(p, th.distributions.MultivariateNormal): + np = th.distributions.MultivariateNormal(mean, scale_tril=chol) new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength, parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type) new.distribution = np return new - def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]: + def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the @@ -133,126 +182,16 @@ class UniversalGaussianDistribution(SB3_Distribution): assert std_init >= 0.0, "std can not be initialized to a negative value." - # TODO: Allow chol to be vector when only diagonal. + # TODO: Implement SDE + self.latent_sde_dim = latent_sde_dim mean_actions = nn.Linear(latent_dim, self.action_dim) - - if self.par_strength == Strength.NONE: - if self.cov_strength == Strength.NONE: - pseudo_cov_par = th.ones(self.action_dim) * std_init - elif self.cov_strength == Strength.SCALAR: - pseudo_cov_par = th.ones(self.action_dim) * \ - nn.Parameter(std_init, requires_grad=True) - pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par) - elif self.cov_strength == Strength.DIAG: - pseudo_cov_par = nn.Parameter( - th.ones(self.action_dim) * std_init, requires_grad=True) - pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par) - elif self.cov_strength == Strength.FULL: - # TODO: Init Off-axis differently? - param = nn.Parameter( - th.ones(self._full_params_len) * std_init, requires_grad=True) - pseudo_cov_par = self._parameterize_full(param) - chol = FakeModule(pseudo_cov_par) - elif self.par_strength == self.cov_strength: - if self.par_strength == Strength.SCALAR: - std = nn.Linear(latent_dim, 1) - diag_chol = th.ones(self.action_dim) * std - chol = self._ensure_positive_func(diag_chol) - elif self.par_strength == Strength.DIAG: - diag_chol = nn.Linear(latent_dim, self.action_dim) - chol = self._ensure_positive_func(diag_chol) - elif self.par_strength == Strength.FULL: - params = nn.Linear(latent_dim, self._full_params_len) - chol = self._parameterize_full(params) - elif self.par_strength > self.cov_strength: - raise Exception( - 'The parameterization can not be stronger than the actual covariance.') - else: - if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: - chol = self._parameterize_hybrid_from_scalar(latent_dim) - elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: - chol = self._parameterize_hybrid_from_diag(latent_dim) - elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: - raise Exception( - 'That does not even make any sense...') - else: - raise Exception("This Exception can't happen") + chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength, + self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type) return mean_actions, chol - @property - def _full_params_len(self): - if self.par_type == ParametrizationType.CHOL: - return self._flat_chol_len - elif self.par_type == ParametrizationType.SPHERICAL_CHOL: - return self._flat_chol_len - raise Exception() - - def _parameterize_full(self, params): - if self.par_type == ParametrizationType.CHOL: - return self._chol_from_flat(params) - elif self.par_type == ParametrizationType.SPHERICAL_CHOL: - return self._chol_from_flat_sphe_chol(params) - raise Exception() - - def _parameterize_hybrid_from_diag(self, params): - # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix) - raise Exception( - 'Programmer-was-to-lazy-to-implement-this-Exception') - - def _parameterize_hybrid_from_scalar(self, latent_dim): - # SCALAR => DIAG - factor = nn.Linear(latent_dim, 1) - par = th.ones(self.action_dim) * \ - nn.Parameter(1, requires_grad=True) - diag_chol = self._ensure_positive_func(par * factor[0]) - return diag_chol - - def _chol_from_flat(self, flat_chol): - chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1) - return self._ensure_diagonal_positive(chol) - - def _chol_from_flat_sphe_chol(self, flat_sphe_chol): - pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol) - sphe_chol = fill_triangular(pos_flat_sphe_chol).expand( - self._flat_chol_len, -1, -1) - chol = self._chol_from_sphe_chol(sphe_chol) - return chol - - def _chol_from_sphe_chol(self, sphe_chol): - # TODO: Test with batched data - # TODO: Make efficient - # Note: - # We must should ensure: - # S[i,1] > 0 where i = 1..n - # S[i,j] e (0, pi) where i = 2..n, j = 2..i - # We already ensure S > 0 in _chol_from_flat_sphe_chol - # We ensure < pi by applying tanh*pi to all applicable elements - S = sphe_chol - n = self.action_dim - L = th.zeros_like(sphe_chol) - for i in range(n): - for j in range(i): - t = S[i, 1] - for k in range(1, j+1): - t *= th.sin(th.tanh(S[i, k])*pi) - if i != j: - t *= th.cos(th.tanh(S[i, j+1])*pi) - L[i, j] = t - return L - - def _ensure_positive_func(self, x): - return self.enforce_positive_type.apply(x) - - def _ensure_diagonal_positive(self, chol): - if len(chol.shape) == 1: - # If our chol is a vector (representing a diagonal chol) - return self._ensure_positive_func(chol) - return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, - dim2=-1)).diag_embed() + chol.triu(1) - - def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution": + def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution": """ Create the distribution given its parameters (mean, chol) @@ -260,6 +199,8 @@ class UniversalGaussianDistribution(SB3_Distribution): :param chol: :return: """ + # TODO: latent_pi is for SDE, implement. + if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]: self.distribution = Normal(mean_actions, chol) elif self.cov_strength in [Strength.FULL]: @@ -306,3 +247,158 @@ class UniversalGaussianDistribution(SB3_Distribution): actions = self.actions_from_params(mean_actions, log_std) log_prob = self.log_prob(actions) return actions, log_prob + + +class CholNet(nn.Module): + def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType): + super().__init__() + self.latent_dim = latent_dim + self.action_dim = action_dim + + self.par_strength = par_strength + self.cov_strength = cov_strength + self.par_type = par_type + self.enforce_positive_type = enforce_positive_type + self.prob_squashing_type = prob_squashing_type + + self._flat_chol_len = action_dim * (action_dim + 1) // 2 + + # Yes, this is ugly. + # But I don't know how this mess could be elegantly abstracted away... + + if self.par_strength == Strength.NONE: + if self.cov_strength == Strength.NONE: + self.chol = th.ones(self.action_dim) * std_init + elif self.cov_strength == Strength.SCALAR: + self.param = nn.Parameter(std_init, requires_grad=True) + elif self.cov_strength == Strength.DIAG: + self.params = nn.Parameter( + th.ones(self.action_dim) * std_init, requires_grad=True) + elif self.cov_strength == Strength.FULL: + # TODO: Init Off-axis differently? + self.params = nn.Parameter( + th.ones(self._full_params_len) * std_init, requires_grad=True) + elif self.par_strength == self.cov_strength: + if self.par_strength == Strength.SCALAR: + self.std = nn.Linear(latent_dim, 1) + elif self.par_strength == Strength.DIAG: + self.diag_chol = nn.Linear(latent_dim, self.action_dim) + elif self.par_strength == Strength.FULL: + self.params = nn.Linear(latent_dim, self._full_params_len) + elif self.par_strength > self.cov_strength: + raise Exception( + 'The parameterization can not be stronger than the actual covariance.') + else: + if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: + self.factor = nn.Linear(latent_dim, 1) + self.param = nn.Parameter(1, requires_grad=True) + elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: + # TODO + pass + elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: + # TODO + pass + else: + raise Exception("This Exception can't happen") + + def forward(self, x: th.Tensor) -> th.Tensor: + # Ugly mess pt.2: + if self.par_strength == Strength.NONE: + if self.cov_strength == Strength.NONE: + return self.chol + elif self.cov_strength == Strength.SCALAR: + return self._ensure_positive_func( + th.ones(self.action_dim) * self.param) + elif self.cov_strength == Strength.DIAG: + return self._ensure_positive_func(self.params) + elif self.cov_strength == Strength.FULL: + return self._parameterize_full(self.params) + elif self.par_strength == self.cov_strength: + if self.par_strength == Strength.SCALAR: + std = self.std(x) + diag_chol = th.ones(self.action_dim) * std + return self._ensure_positive_func(diag_chol) + elif self.par_strength == Strength.DIAG: + diag_chol = self.diag_chol(x) + return self._ensure_positive_func(diag_chol) + elif self.par_strength == Strength.FULL: + params = self.params(x) + return self._parameterize_full(params) + else: + if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: + factor = self.factor(x) + diag_chol = self._ensure_positive_func( + th.ones(self.action_dim) * self.param * factor[0]) + return diag_chol + elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: + pass + # TODO + elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: + # TODO + pass + raise Exception() + + @property + def _full_params_len(self): + if self.par_type == ParametrizationType.CHOL: + return self._flat_chol_len + elif self.par_type == ParametrizationType.SPHERICAL_CHOL: + return self._flat_chol_len + raise Exception() + + def _parameterize_full(self, params): + if self.par_type == ParametrizationType.CHOL: + return self._chol_from_flat(params) + elif self.par_type == ParametrizationType.SPHERICAL_CHOL: + return self._chol_from_flat_sphe_chol(params) + raise Exception() + + def _chol_from_flat(self, flat_chol): + chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1) + return self._ensure_diagonal_positive(chol) + + def _chol_from_flat_sphe_chol(self, flat_sphe_chol): + pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol) + sphe_chol = fill_triangular(pos_flat_sphe_chol).expand( + self._flat_chol_len, -1, -1) + chol = self._chol_from_sphe_chol(sphe_chol) + return chol + + def _chol_from_sphe_chol(self, sphe_chol): + # TODO: Test with batched data + # TODO: Make efficient more + # Note: + # We must should ensure: + # S[i,1] > 0 where i = 1..n + # S[i,j] e (0, pi) where i = 2..n, j = 2..i + # We already ensure S > 0 in _chol_from_flat_sphe_chol + # We ensure < pi by applying tanh*pi to all applicable elements + S = sphe_chol + n = self.action_dim + L = th.zeros_like(sphe_chol) + for i in range(n): + for j in range(i): + t = S[i, 1] + for k in range(1, j+1): + t *= th.sin(th.tanh(S[i, k])*pi) + if i != j: + t *= th.cos(th.tanh(S[i, j+1])*pi) + L[i, j] = t + return L + + def _ensure_positive_func(self, x): + return self.enforce_positive_type.apply(x) + + def _ensure_diagonal_positive(self, chol): + if len(chol.shape) == 1: + # If our chol is a vector (representing a diagonal chol) + return self._ensure_positive_func(chol) + return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2, + dim2=-1)).diag_embed() + chol.triu(1) + + def string(self): + # TODO + return '' + + +AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution]