From a8b9c639656f80e01e8b8b31d7fce49fd642fcda Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 1 Jul 2022 11:29:12 +0200 Subject: [PATCH] Making dez covariances contextual --- .../distributions/distributions.py | 191 +++++++++++++++++- 1 file changed, 184 insertions(+), 7 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index b7c32ab..fe4005f 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional, Tuple, Union +from enum import Enum import torch as th from torch import nn @@ -6,20 +7,61 @@ from torch.distributions import Normal, MultivariateNormal from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.distributions import sum_independent_dims from stable_baselines3.common.distributions import Distribution as SB3_Distribution from stable_baselines3.common.distributions import DiagGaussianDistribution -class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution): +# TODO: Full Cov Parameter +# TODO: Contextual Cov +# TODO: - Scalar +# TODO: - Diag +# TODO: - Full +# TODO: - Hybrid +# TODO: Contextual SDE (Scalar + Diag + Full) +# TODO: (SqrtInducedCov (Scalar + Diag + Full)) +# TODO: (Support Squased Dists (tanh)) + +class Strength(Enum): + NONE = 0 + SCALAR = 1 + DIAG = 2 + FULL = 3 + + def __init__(self, num): + self.num = num + + @property + def foo(self): + return self.num + + +class ParametrizationType(Enum): + CHOL = 0 + ARCHAKOVA = 1 + + +class EnforcePositiveType(Enum): + LOG = 0 + RELU = 1 + SELU = 2 + ABS = 3 + SQ = 4 + + +class UniversalGaussianDistribution(SB3_Distribution): """ - Gaussian distribution with diagonal covariance matrix, for continuous actions. - Includes contextual parametrization of the covariance matrix. + Gaussian distribution with configurable covariance matrix shape and optional contextual parametrization mechanism, for continuous actions. :param action_dim: Dimension of the action space. """ def __init__(self, action_dim: int): - super(ContextualCovDiagonalGaussianDistribution, self).__init__() + super(UniversalGaussianDistribution, self).__init__() + self.par_strength = Strength.DIAG + self.cov_strength = Strength.DIAG + self.par_type = None + self.enforce_positive_type = None def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: """ @@ -32,11 +74,147 @@ class ContextualCovDiagonalGaussianDistribution(DiagGaussianDistribution): :return: """ mean_actions = nn.Linear(latent_dim, self.action_dim) - log_std = nn.Linear(latent_dim, self.action_dim) + if self.contextual_cov: + log_std = nn.Linear(latent_dim, self.action_dim) + else: + log_std = nn.Parameter( + th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution": + """ + Create the distribution given its parameters (mean, std) -class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution): + :param mean_actions: + :param log_std: + :return: + """ + action_std = th.ones_like(mean_actions) * log_std.exp() + self.distribution = Normal(mean_actions, action_std) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: + :return: + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + return self.distribution.rsample() + + def mode(self) -> th.Tensor: + return self.distribution.mean + + def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(mean_actions, log_std) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + """ + Compute the log probability of taking an action + given the distribution parameters. + + :param mean_actions: + :param log_std: + :return: + """ + actions = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class DiagGaussianDistribution(SB3_Distribution): + """ + Gaussian distribution with full covariance matrix, for continuous actions. + + :param action_dim: Dimension of the action space. + """ + + def __init__(self, action_dim: int): + super(DiagGaussianDistribution, self).__init__() + self.action_dim = action_dim + self.mean_actions = None + self.log_std = None + + def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: + """ + Create the layers and parameter that represent the distribution: + one output will be the mean of the Gaussian, the other parameter will be the + standard deviation (log std in fact to allow negative values) + + :param latent_dim: Dimension of the last layer of the policy (before the action layer) + :param log_std_init: Initial value for the log standard deviation + :return: + """ + mean_actions = nn.Linear(latent_dim, self.action_dim) + # TODO: allow action dependent std + log_std = nn.Parameter(th.ones(self.action_dim) + * log_std_init, requires_grad=True) + return mean_actions, log_std + + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution": + """ + Create the distribution given its parameters (mean, std) + + :param mean_actions: + :param log_std: + :return: + """ + action_std = th.ones_like(mean_actions) * log_std.exp() + self.distribution = Normal(mean_actions, action_std) + return self + + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: + :return: + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + return self.distribution.rsample() + + def mode(self) -> th.Tensor: + return self.distribution.mean + + def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: + # Update the proba distribution + self.proba_distribution(mean_actions, log_std) + return self.get_actions(deterministic=deterministic) + + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + """ + Compute the log probability of taking an action + given the distribution parameters. + + :param mean_actions: + :param log_std: + :return: + """ + actions = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(actions) + return actions, log_prob + + +class ContextualSqrtInducedCovDiagonalGaussianDistribution(DiagGaussianDistribution): """ Gaussian distribution induced by its sqrt(cov), for continuous actions. @@ -60,7 +238,6 @@ class ContextualSqrtCovDiagonalGaussianDistribution(DiagGaussianDistribution): :return: """ mean_actions = nn.Linear(latent_dim, self.action_dim) - # TODO: allow action dependent std log_std = nn.Linear(latent_dim, (self.action_dim, self.action_dim)) return mean_actions, log_std