Working on UniversalGaussianDistribution

Dominik Moritz Roth 2022-07-13 19:38:57 +02:00
parent fae19509bc
commit 3304fd49f6


@@ -1,6 +1,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 from enum import Enum

+import gym
 import torch as th
 from torch import nn
 from torch.distributions import Normal, MultivariateNormal
@@ -10,10 +11,14 @@ from stable_baselines3.common.preprocessing import get_action_dim
 from stable_baselines3.common.distributions import sum_independent_dims
 from stable_baselines3.common.distributions import Distribution as SB3_Distribution
+from stable_baselines3.common.distributions import (
+    BernoulliDistribution,
+    CategoricalDistribution,
+    MultiCategoricalDistribution,
+    # StateDependentNoiseDistribution,
+)
 from stable_baselines3.common.distributions import DiagGaussianDistribution
-from ..misc.fakeModule import FakeModule
-from ..misc.distTools import new_dist_like
 from ..misc.tensor_ops import fill_triangular

 # TODO: Integrate and Test what I currently have before adding more complexity
@@ -34,7 +39,9 @@ class Strength(Enum):
 class ParametrizationType(Enum):
     CHOL = 1
     SPHERICAL_CHOL = 2
+    # Not (yet?) implemented:
     #GIVENS = 3
+    #NNLN_EIGEN = 4


 class EnforcePositiveType(Enum):
@@ -45,7 +52,7 @@ class EnforcePositiveType(Enum):
     LOG = (4, th.log)

     def __init__(self, value, func):
-        self.value = value
+        self.val = value
         self._func = func

     def apply(self, x):
@@ -57,7 +64,7 @@ class ProbSquashingType(Enum):
     TANH = (1, th.tanh)

     def __init__(self, value, func):
-        self.value = value
+        self.val = value
         self._func = func

     def apply(self, x):
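The rename from `self.value` to `self.val` in both enums is needed because Python's `Enum` reserves the `value` attribute for the member's full value (here, the whole tuple) and does not allow assigning to it from `__init__`. Below is a minimal standalone sketch of the tuple-valued-enum pattern used here; only `LOG = (4, th.log)` is visible in the diff, the other member is assumed for illustration.

from enum import Enum
import torch as th


class EnforcePositiveType(Enum):
    # Member .value is the whole tuple, e.g. EnforcePositiveType.ABS.value == (1, th.abs)
    ABS = (1, th.abs)
    LOG = (4, th.log)

    def __init__(self, value, func):
        # Enum already reserves `self.value` for the tuple itself,
        # so the numeric id is stored under a different attribute name.
        self.val = value
        self._func = func

    def apply(self, x):
        return self._func(x)


print(EnforcePositiveType.ABS.apply(th.tensor([-1.0, 2.0])))  # tensor([1., 2.])

With this pattern, `EnforcePositiveType.LOG.val` is `4`, while `EnforcePositiveType.LOG.value` stays the full tuple `(4, th.log)`.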
@@ -92,6 +99,38 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
                     yield (ps, cs, ept, None)


+def make_proba_distribution(
+    action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+) -> SB3_Distribution:
+    """
+    Return an instance of Distribution for the correct type of action space
+
+    :param action_space: the input action space
+    :param use_sde: Force the use of StateDependentNoiseDistribution
+        instead of DiagGaussianDistribution
+    :param dist_kwargs: Keyword arguments to pass to the probability distribution
+    :return: the appropriate Distribution object
+    """
+    if dist_kwargs is None:
+        dist_kwargs = {}
+
+    if isinstance(action_space, gym.spaces.Box):
+        assert len(
+            action_space.shape) == 1, "Error: the action space must be a vector"
+        return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs)
+    elif isinstance(action_space, gym.spaces.Discrete):
+        return CategoricalDistribution(action_space.n, **dist_kwargs)
+    elif isinstance(action_space, gym.spaces.MultiDiscrete):
+        return MultiCategoricalDistribution(action_space.nvec, **dist_kwargs)
+    elif isinstance(action_space, gym.spaces.MultiBinary):
+        return BernoulliDistribution(action_space.n, **dist_kwargs)
+    else:
+        raise NotImplementedError(
+            "Error: probability distribution, not implemented for action space"
+            f"of type {type(action_space)}."
+            " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
+        )
+
+
 class UniversalGaussianDistribution(SB3_Distribution):
     """
     Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions.
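The added `make_proba_distribution` mirrors the factory of the same name in stable-baselines3, except that `Box` action spaces now yield a `UniversalGaussianDistribution`. A usage sketch; the import path is hypothetical, since the file's location inside the package is not shown in this diff.

import gym

# Hypothetical module path for the file changed in this commit.
from universal_gaussian import make_proba_distribution, UniversalGaussianDistribution

dist = make_proba_distribution(gym.spaces.Box(low=-1.0, high=1.0, shape=(3,)))
assert isinstance(dist, UniversalGaussianDistribution)
assert dist.action_dim == 3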
@@ -99,8 +138,9 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """

-    def __init__(self, action_dim: int, neural_strength=Strength.DIAG, cov_strength=Strength.DIAG, parameterization_type=Strength.CHOL, enforce_positive_type=EnforcePositiveType.ABS, prob_squashing_type=ProbSquashingType.TANH):
+    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
         super(UniversalGaussianDistribution, self).__init__()
+        self.action_dim = action_dim
         self.par_strength = neural_strength
         self.cov_strength = cov_strength
         self.par_type = parameterization_type
@@ -109,18 +149,27 @@ class UniversalGaussianDistribution(SB3_Distribution):
         self.distribution = None

-        self._flat_chol_len = action_dim * (action_dim + 1) // 2
+        if self.prob_squashing_type != ProbSquashingType.NONE:
+            raise Exception('ProbSquashing is not yet implemented!')
+
+        if use_sde:
+            raise Exception('SDE is not yet implemented')

-    def new_dist_like_me(self, mean, pseudo_chol):
+    def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
         p = self.distribution
-        np = new_dist_like(p, mean, pseudo_chol)
+        if isinstance(p, th.distributions.Normal):
+            if p.stddev.shape != chol.shape:
+                chol = th.diagonal(chol, dim1=1, dim2=2)
+            np = th.distributions.Normal(mean, chol)
+        elif isinstance(p, th.distributions.MultivariateNormal):
+            np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
         new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
                                             parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
         new.distribution = np
         return new

-    def proba_distribution_net(self, latent_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
+    def proba_distribution_net(self, latent_dim: int, latent_sde_dim: int, std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]:
         """
         Create the layers and parameter that represent the distribution:
         one output will be the mean of the Gaussian, the other parameter will be the
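`new_dist_like_me` now rebuilds the underlying torch distribution directly instead of going through `..misc.distTools.new_dist_like`: in the diagonal case a `Normal` over per-dimension standard deviations, in the full case a `MultivariateNormal` parameterized by its Cholesky factor. A standalone sketch of those two constructions, with shapes chosen purely for illustration.

import torch as th

mean = th.zeros(2, 3)                      # batch of 2, action_dim 3
chol = th.eye(3).expand(2, 3, 3) * 0.5     # batched lower-triangular scale factor

# Full covariance: pass the Cholesky factor directly via scale_tril.
mvn = th.distributions.MultivariateNormal(mean, scale_tril=chol)

# Diagonal case: Normal only takes per-dimension std, so the diagonal
# is extracted from the batched matrix, as in new_dist_like_me.
normal = th.distributions.Normal(mean, th.diagonal(chol, dim1=1, dim2=2))

print(mvn.sample().shape, normal.sample().shape)  # torch.Size([2, 3]) in both cases

Passing `scale_tril` avoids an internal Cholesky factorization of a covariance matrix, which is one reason the class carries the factor around rather than the covariance itself.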
@@ -133,126 +182,16 @@ class UniversalGaussianDistribution(SB3_Distribution):
         assert std_init >= 0.0, "std can not be initialized to a negative value."

-        # TODO: Allow chol to be vector when only diagonal.
+        # TODO: Implement SDE
+        self.latent_sde_dim = latent_sde_dim

         mean_actions = nn.Linear(latent_dim, self.action_dim)
+        chol = CholNet(latent_dim, self.action_dim, std_init, self.par_strength,
+                       self.cov_strength, self.par_type, self.enforce_positive_type, self.prob_squashing_type)

-        if self.par_strength == Strength.NONE:
-            if self.cov_strength == Strength.NONE:
-                pseudo_cov_par = th.ones(self.action_dim) * std_init
-            elif self.cov_strength == Strength.SCALAR:
-                pseudo_cov_par = th.ones(self.action_dim) * \
-                    nn.Parameter(std_init, requires_grad=True)
-                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
-            elif self.cov_strength == Strength.DIAG:
-                pseudo_cov_par = nn.Parameter(
-                    th.ones(self.action_dim) * std_init, requires_grad=True)
-                pseudo_cov_par = self._ensure_positive_func(pseudo_cov_par)
-            elif self.cov_strength == Strength.FULL:
-                # TODO: Init Off-axis differently?
-                param = nn.Parameter(
-                    th.ones(self._full_params_len) * std_init, requires_grad=True)
-                pseudo_cov_par = self._parameterize_full(param)
-            chol = FakeModule(pseudo_cov_par)
-        elif self.par_strength == self.cov_strength:
-            if self.par_strength == Strength.SCALAR:
-                std = nn.Linear(latent_dim, 1)
-                diag_chol = th.ones(self.action_dim) * std
-                chol = self._ensure_positive_func(diag_chol)
-            elif self.par_strength == Strength.DIAG:
-                diag_chol = nn.Linear(latent_dim, self.action_dim)
-                chol = self._ensure_positive_func(diag_chol)
-            elif self.par_strength == Strength.FULL:
-                params = nn.Linear(latent_dim, self._full_params_len)
-                chol = self._parameterize_full(params)
-        elif self.par_strength > self.cov_strength:
-            raise Exception(
-                'The parameterization can not be stronger than the actual covariance.')
-        else:
-            if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
-                chol = self._parameterize_hybrid_from_scalar(latent_dim)
-            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
-                chol = self._parameterize_hybrid_from_diag(latent_dim)
-            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
-                raise Exception(
-                    'That does not even make any sense...')
-            else:
-                raise Exception("This Exception can't happen")
-
         return mean_actions, chol

-    @property
-    def _full_params_len(self):
-        if self.par_type == ParametrizationType.CHOL:
-            return self._flat_chol_len
-        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
-            return self._flat_chol_len
-        raise Exception()
-
-    def _parameterize_full(self, params):
-        if self.par_type == ParametrizationType.CHOL:
-            return self._chol_from_flat(params)
-        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
-            return self._chol_from_flat_sphe_chol(params)
-        raise Exception()
-
-    def _parameterize_hybrid_from_diag(self, params):
-        # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix)
-        raise Exception(
-            'Programmer-was-to-lazy-to-implement-this-Exception')
-
-    def _parameterize_hybrid_from_scalar(self, latent_dim):
-        # SCALAR => DIAG
-        factor = nn.Linear(latent_dim, 1)
-        par = th.ones(self.action_dim) * \
-            nn.Parameter(1, requires_grad=True)
-        diag_chol = self._ensure_positive_func(par * factor[0])
-        return diag_chol
-
-    def _chol_from_flat(self, flat_chol):
-        chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
-        return self._ensure_diagonal_positive(chol)
-
-    def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
-        pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
-        sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
-            self._flat_chol_len, -1, -1)
-        chol = self._chol_from_sphe_chol(sphe_chol)
-        return chol
-
-    def _chol_from_sphe_chol(self, sphe_chol):
-        # TODO: Test with batched data
-        # TODO: Make efficient
-        # Note:
-        # We must should ensure:
-        # S[i,1] > 0 where i = 1..n
-        # S[i,j] e (0, pi) where i = 2..n, j = 2..i
-        # We already ensure S > 0 in _chol_from_flat_sphe_chol
-        # We ensure < pi by applying tanh*pi to all applicable elements
-        S = sphe_chol
-        n = self.action_dim
-        L = th.zeros_like(sphe_chol)
-        for i in range(n):
-            for j in range(i):
-                t = S[i, 1]
-                for k in range(1, j+1):
-                    t *= th.sin(th.tanh(S[i, k])*pi)
-                if i != j:
-                    t *= th.cos(th.tanh(S[i, j+1])*pi)
-                L[i, j] = t
-        return L
-
-    def _ensure_positive_func(self, x):
-        return self.enforce_positive_type.apply(x)
-
-    def _ensure_diagonal_positive(self, chol):
-        if len(chol.shape) == 1:
-            # If our chol is a vector (representing a diagonal chol)
-            return self._ensure_positive_func(chol)
-        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
-                                                                        dim2=-1)).diag_embed() + chol.triu(1)
-
-    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor) -> "UniversalGaussianDistribution":
+    def proba_distribution(self, mean_actions: th.Tensor, chol: th.Tensor, latent_pi: nn.Module) -> "UniversalGaussianDistribution":
         """
         Create the distribution given its parameters (mean, chol)
@@ -260,6 +199,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :param chol:
         :return:
         """
+        # TODO: latent_pi is for SDE, implement.
+
         if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
             self.distribution = Normal(mean_actions, chol)
         elif self.cov_strength in [Strength.FULL]:
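Both the methods removed above and the `CholNet` added below keep a full covariance as a flat vector of `action_dim * (action_dim + 1) // 2` entries, which `fill_triangular` (from `..misc.tensor_ops`, not shown in this diff) packs into a lower-triangular matrix. A rough stand-in sketch of that bookkeeping; the real helper may use a different element ordering.

import torch as th

def fill_triangular_sketch(flat: th.Tensor, n: int) -> th.Tensor:
    # Stand-in for ..misc.tensor_ops.fill_triangular: scatter a flat vector
    # of length n*(n+1)//2 into a lower-triangular (n, n) matrix.
    out = th.zeros(n, n, dtype=flat.dtype)
    idx = th.tril_indices(n, n)
    out[idx[0], idx[1]] = flat
    return out

n = 3
flat = th.arange(1.0, n * (n + 1) // 2 + 1)   # 6 parameters for a 3x3 factor
L = fill_triangular_sketch(flat, n)
print(L)
# To make the diagonal strictly positive (cf. _ensure_diagonal_positive),
# an element-wise positive map such as softplus can be applied to L.diagonal().

For `action_dim = 3` this amounts to 6 learned parameters instead of the 9 entries of a dense matrix.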
@@ -306,3 +247,158 @@ class UniversalGaussianDistribution(SB3_Distribution):
         actions = self.actions_from_params(mean_actions, log_std)
         log_prob = self.log_prob(actions)
         return actions, log_prob
+
+
+class CholNet(nn.Module):
+    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: Strength, cov_strength: Strength, par_type: ParametrizationType, enforce_positive_type: EnforcePositiveType, prob_squashing_type: ProbSquashingType):
+        super().__init__()
+        self.latent_dim = latent_dim
+        self.action_dim = action_dim
+
+        self.par_strength = par_strength
+        self.cov_strength = cov_strength
+        self.par_type = par_type
+        self.enforce_positive_type = enforce_positive_type
+        self.prob_squashing_type = prob_squashing_type
+
+        self._flat_chol_len = action_dim * (action_dim + 1) // 2
+
+        # Yes, this is ugly.
+        # But I don't know how this mess could be elegantly abstracted away...
+        if self.par_strength == Strength.NONE:
+            if self.cov_strength == Strength.NONE:
+                self.chol = th.ones(self.action_dim) * std_init
+            elif self.cov_strength == Strength.SCALAR:
+                self.param = nn.Parameter(std_init, requires_grad=True)
+            elif self.cov_strength == Strength.DIAG:
+                self.params = nn.Parameter(
+                    th.ones(self.action_dim) * std_init, requires_grad=True)
+            elif self.cov_strength == Strength.FULL:
+                # TODO: Init Off-axis differently?
+                self.params = nn.Parameter(
+                    th.ones(self._full_params_len) * std_init, requires_grad=True)
+        elif self.par_strength == self.cov_strength:
+            if self.par_strength == Strength.SCALAR:
+                self.std = nn.Linear(latent_dim, 1)
+            elif self.par_strength == Strength.DIAG:
+                self.diag_chol = nn.Linear(latent_dim, self.action_dim)
+            elif self.par_strength == Strength.FULL:
+                self.params = nn.Linear(latent_dim, self._full_params_len)
+        elif self.par_strength > self.cov_strength:
+            raise Exception(
+                'The parameterization can not be stronger than the actual covariance.')
+        else:
+            if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
+                self.factor = nn.Linear(latent_dim, 1)
+                self.param = nn.Parameter(1, requires_grad=True)
+            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
+                # TODO
+                pass
+            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
+                # TODO
+                pass
+            else:
+                raise Exception("This Exception can't happen")
+
+    def forward(self, x: th.Tensor) -> th.Tensor:
+        # Ugly mess pt.2:
+        if self.par_strength == Strength.NONE:
+            if self.cov_strength == Strength.NONE:
+                return self.chol
+            elif self.cov_strength == Strength.SCALAR:
+                return self._ensure_positive_func(
+                    th.ones(self.action_dim) * self.param)
+            elif self.cov_strength == Strength.DIAG:
+                return self._ensure_positive_func(self.params)
+            elif self.cov_strength == Strength.FULL:
+                return self._parameterize_full(self.params)
+        elif self.par_strength == self.cov_strength:
+            if self.par_strength == Strength.SCALAR:
+                std = self.std(x)
+                diag_chol = th.ones(self.action_dim) * std
+                return self._ensure_positive_func(diag_chol)
+            elif self.par_strength == Strength.DIAG:
+                diag_chol = self.diag_chol(x)
+                return self._ensure_positive_func(diag_chol)
+            elif self.par_strength == Strength.FULL:
+                params = self.params(x)
+                return self._parameterize_full(params)
+        else:
+            if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
+                factor = self.factor(x)
+                diag_chol = self._ensure_positive_func(
+                    th.ones(self.action_dim) * self.param * factor[0])
+                return diag_chol
+            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
+                # TODO
+                pass
+            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
+                # TODO
+                pass
+        raise Exception()
+
+    @property
+    def _full_params_len(self):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._flat_chol_len
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._flat_chol_len
+        raise Exception()
+
+    def _parameterize_full(self, params):
+        if self.par_type == ParametrizationType.CHOL:
+            return self._chol_from_flat(params)
+        elif self.par_type == ParametrizationType.SPHERICAL_CHOL:
+            return self._chol_from_flat_sphe_chol(params)
+        raise Exception()
+
+    def _chol_from_flat(self, flat_chol):
+        chol = fill_triangular(flat_chol).expand(self._flat_chol_len, -1, -1)
+        return self._ensure_diagonal_positive(chol)
+
+    def _chol_from_flat_sphe_chol(self, flat_sphe_chol):
+        pos_flat_sphe_chol = self._ensure_positive_func(flat_sphe_chol)
+        sphe_chol = fill_triangular(pos_flat_sphe_chol).expand(
+            self._flat_chol_len, -1, -1)
+        chol = self._chol_from_sphe_chol(sphe_chol)
+        return chol
+
+    def _chol_from_sphe_chol(self, sphe_chol):
+        # TODO: Test with batched data
+        # TODO: Make more efficient
+        # Note:
+        # We should ensure:
+        # S[i,1] > 0 where i = 1..n
+        # S[i,j] e (0, pi) where i = 2..n, j = 2..i
+        # We already ensure S > 0 in _chol_from_flat_sphe_chol
+        # We ensure < pi by applying tanh*pi to all applicable elements
+        S = sphe_chol
+        n = self.action_dim
+        L = th.zeros_like(sphe_chol)
+        for i in range(n):
+            for j in range(i):
+                t = S[i, 1]
+                for k in range(1, j+1):
+                    t *= th.sin(th.tanh(S[i, k])*pi)
+                if i != j:
+                    t *= th.cos(th.tanh(S[i, j+1])*pi)
+                L[i, j] = t
+        return L
+
+    def _ensure_positive_func(self, x):
+        return self.enforce_positive_type.apply(x)
+
+    def _ensure_diagonal_positive(self, chol):
+        if len(chol.shape) == 1:
+            # If our chol is a vector (representing a diagonal chol)
+            return self._ensure_positive_func(chol)
+        return chol.tril(-1) + self._ensure_positive_func(chol.diagonal(dim1=-2,
+                                                                        dim2=-1)).diag_embed() + chol.triu(1)
+
+    def string(self):
+        # TODO
+        return '<CholNet />'
+
+
+AnyDistribution = Union[SB3_Distribution, UniversalGaussianDistribution]
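For the spherical parametrization (`ParametrizationType.SPHERICAL_CHOL`), the comments in `_chol_from_sphe_chol` rely on the fact that a row built from a positive radius and angles in (0, pi) has exactly that radius as its Euclidean norm and keeps its trailing (diagonal) entry positive. A small standalone check of that property, independent of the batching, masking and tanh details in the commit.

import torch as th

def sphe_row(r: th.Tensor, angles: th.Tensor) -> th.Tensor:
    # One row of a spherical-coordinate Cholesky factor:
    # entries are r times cos/sin products, so the row norm is exactly r
    # and the trailing entry stays positive for angles in (0, pi).
    out = []
    prod = r
    for phi in angles:
        out.append(prod * th.cos(phi))
        prod = prod * th.sin(phi)
    out.append(prod)
    return th.stack(out)

r = th.tensor(2.0)
angles = th.tensor([0.3, 1.2]) * th.pi / 2   # kept inside (0, pi)
row = sphe_row(r, angles)
print(row, th.linalg.norm(row))  # the norm equals r == 2.0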