Making dez covariances contextual

This commit is contained in:
Dominik Moritz Roth 2022-07-01 11:29:12 +02:00
parent 155a298e41
commit a8b9c63965

View File

@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from enum import Enum
import torch as th
from torch import nn
@ -6,20 +7,61 @@ from torch.distributions import Normal, MultivariateNormal
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.distributions import sum_independent_dims
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
from stable_baselines3.common.distributions import DiagGaussianDistribution
class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution):
# TODO: Full Cov Parameter
# TODO: Contextual Cov
# TODO: - Scalar
# TODO: - Diag
# TODO: - Full
# TODO: - Hybrid
# TODO: Contextual SDE (Scalar + Diag + Full)
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
# TODO: (Support Squased Dists (tanh))
class Strength(Enum):
NONE = 0
SCALAR = 1
DIAG = 2
FULL = 3
def __init__(self, num):
self.num = num
@property
def foo(self):
return self.num
class ParametrizationType(Enum):
CHOL = 0
ARCHAKOVA = 1
class EnforcePositiveType(Enum):
LOG = 0
RELU = 1
SELU = 2
ABS = 3
SQ = 4
class UniversalGaussianDistribution(SB3_Distribution):
"""
Gaussian distribution with diagonal covariance matrix, for continuous actions.
Includes contextual parametrization of the covariance matrix.
Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions.
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int):
super(ContextualCovDiagonalGaussianDistribution, self).__init__()
super(UniversalGaussianDistribution, self).__init__()
self.par_strength = Strength.DIAG
self.cov_strength = Strength.DIAG
self.par_type = None
self.enforce_positive_type = None
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
"""
@ -32,11 +74,147 @@ class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution):
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
if self.contextual_cov:
log_std = nn.Linear(latent_dim, self.action_dim)
else:
log_std = nn.Parameter(
th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
"""
Create the distribution given its parameters (mean, std)
class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution):
:param mean_actions:
:param log_std:
:return:
"""
action_std = th.ones_like(mean_actions) * log_std.exp()
self.distribution = Normal(mean_actions, action_std)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions:
:return:
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
:param mean_actions:
:param log_std:
:return:
"""
actions = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(actions)
return actions, log_prob
class DiagGaussianDistribution(SB3_Distribution):
"""
Gaussian distribution with full covariance matrix, for continuous actions.
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int):
super(DiagGaussianDistribution, self).__init__()
self.action_dim = action_dim
self.mean_actions = None
self.log_std = None
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: allow action dependent std
log_std = nn.Parameter(th.ones(self.action_dim)
* log_std_init, requires_grad=True)
return mean_actions, log_std
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
"""
Create the distribution given its parameters (mean, std)
:param mean_actions:
:param log_std:
:return:
"""
action_std = th.ones_like(mean_actions) * log_std.exp()
self.distribution = Normal(mean_actions, action_std)
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions:
:return:
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
:param mean_actions:
:param log_std:
:return:
"""
actions = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(actions)
return actions, log_prob
class ContextualSqrtInducedCovDiagonalGaussianDistribution(DiagGaussianDistribution):
"""
Gaussian distribution induced by its sqrt(cov), for continuous actions.
@ -60,7 +238,6 @@ class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution):
:return:
"""
mean_actions = nn.Linear(latent_dim, self.action_dim)
# TODO: allow action dependent std
log_std = nn.Linear(latent_dim, (self.action_dim, self.action_dim))
return mean_actions, log_std