Making dez covariances contextual
This commit is contained in:
parent
155a298e41
commit
a8b9c63965
@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from enum import Enum
|
||||
|
||||
import torch as th
|
||||
from torch import nn
|
||||
@ -6,20 +7,61 @@ from torch.distributions import Normal, MultivariateNormal
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
from stable_baselines3.common.distributions import sum_independent_dims
|
||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
|
||||
from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||
|
||||
|
||||
class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution):
|
||||
# TODO: Full Cov Parameter
|
||||
# TODO: Contextual Cov
|
||||
# TODO: - Scalar
|
||||
# TODO: - Diag
|
||||
# TODO: - Full
|
||||
# TODO: - Hybrid
|
||||
# TODO: Contextual SDE (Scalar + Diag + Full)
|
||||
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
||||
# TODO: (Support Squased Dists (tanh))
|
||||
|
||||
class Strength(Enum):
|
||||
NONE = 0
|
||||
SCALAR = 1
|
||||
DIAG = 2
|
||||
FULL = 3
|
||||
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
|
||||
@property
|
||||
def foo(self):
|
||||
return self.num
|
||||
|
||||
|
||||
class ParametrizationType(Enum):
|
||||
CHOL = 0
|
||||
ARCHAKOVA = 1
|
||||
|
||||
|
||||
class EnforcePositiveType(Enum):
|
||||
LOG = 0
|
||||
RELU = 1
|
||||
SELU = 2
|
||||
ABS = 3
|
||||
SQ = 4
|
||||
|
||||
|
||||
class UniversalGaussianDistribution(SB3_Distribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix, for continuous actions.
|
||||
Includes contextual parametrization of the covariance matrix.
|
||||
Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions.
|
||||
|
||||
:param action_dim: Dimension of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
super(ContextualCovDiagonalGaussianDistribution, self).__init__()
|
||||
super(UniversalGaussianDistribution, self).__init__()
|
||||
self.par_strength = Strength.DIAG
|
||||
self.cov_strength = Strength.DIAG
|
||||
self.par_type = None
|
||||
self.enforce_positive_type = None
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
|
||||
"""
|
||||
@ -32,11 +74,147 @@ class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution):
|
||||
:return:
|
||||
"""
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
log_std = nn.Linear(latent_dim, self.action_dim)
|
||||
if self.contextual_cov:
|
||||
log_std = nn.Linear(latent_dim, self.action_dim)
|
||||
else:
|
||||
log_std = nn.Parameter(
|
||||
th.ones(self.action_dim) * log_std_init, requires_grad=True)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution):
|
||||
:param mean_actions:
|
||||
:param log_std:
|
||||
:return:
|
||||
"""
|
||||
action_std = th.ones_like(mean_actions) * log_std.exp()
|
||||
self.distribution = Normal(mean_actions, action_std)
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must first call the ``proba_distribution()`` method.
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
return self.distribution.rsample()
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
# Update the proba distribution
|
||||
self.proba_distribution(mean_actions, log_std)
|
||||
return self.get_actions(deterministic=deterministic)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Compute the log probability of taking an action
|
||||
given the distribution parameters.
|
||||
|
||||
:param mean_actions:
|
||||
:param log_std:
|
||||
:return:
|
||||
"""
|
||||
actions = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
|
||||
class DiagGaussianDistribution(SB3_Distribution):
|
||||
"""
|
||||
Gaussian distribution with full covariance matrix, for continuous actions.
|
||||
|
||||
:param action_dim: Dimension of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int):
|
||||
super(DiagGaussianDistribution, self).__init__()
|
||||
self.action_dim = action_dim
|
||||
self.mean_actions = None
|
||||
self.log_std = None
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
|
||||
"""
|
||||
Create the layers and parameter that represent the distribution:
|
||||
one output will be the mean of the Gaussian, the other parameter will be the
|
||||
standard deviation (log std in fact to allow negative values)
|
||||
|
||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:return:
|
||||
"""
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
# TODO: allow action dependent std
|
||||
log_std = nn.Parameter(th.ones(self.action_dim)
|
||||
* log_std_init, requires_grad=True)
|
||||
return mean_actions, log_std
|
||||
|
||||
def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
|
||||
"""
|
||||
Create the distribution given its parameters (mean, std)
|
||||
|
||||
:param mean_actions:
|
||||
:param log_std:
|
||||
:return:
|
||||
"""
|
||||
action_std = th.ones_like(mean_actions) * log_std.exp()
|
||||
self.distribution = Normal(mean_actions, action_std)
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must first call the ``proba_distribution()`` method.
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
return self.distribution.rsample()
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
# Update the proba distribution
|
||||
self.proba_distribution(mean_actions, log_std)
|
||||
return self.get_actions(deterministic=deterministic)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Compute the log probability of taking an action
|
||||
given the distribution parameters.
|
||||
|
||||
:param mean_actions:
|
||||
:param log_std:
|
||||
:return:
|
||||
"""
|
||||
actions = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
|
||||
class ContextualSqrtInducedCovDiagonalGaussianDistribution(DiagGaussianDistribution):
|
||||
"""
|
||||
Gaussian distribution induced by its sqrt(cov), for continuous actions.
|
||||
|
||||
@ -60,7 +238,6 @@ class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution):
|
||||
:return:
|
||||
"""
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
# TODO: allow action dependent std
|
||||
log_std = nn.Linear(latent_dim, (self.action_dim, self.action_dim))
|
||||
return mean_actions, log_std
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user