diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index fa5418f..05860c0 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -11,6 +11,7 @@ from stable_baselines3.common.distributions import sum_independent_dims from stable_baselines3.common.distributions import Distribution as SB3_Distribution from stable_baselines3.common.distributions import DiagGaussianDistribution +from ..misc.fakeModule import FakeModule # TODO: Full Cov Parameter # TODO: Contextual Cov @@ -22,6 +23,7 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution # TODO: (SqrtInducedCov (Scalar + Diag + Full)) # TODO: (Support Squased Dists (tanh)) + class Strength(Enum): NONE = 0 SCALAR = 1 @@ -65,7 +67,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.distribution = None - def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: + def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Module]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the @@ -73,51 +75,72 @@ class UniversalGaussianDistribution(SB3_Distribution): :param latent_dim: Dimension of the last layer of the policy (before the action layer) :param log_std_init: Initial value for the log standard deviation - :return: + :return: We return two nn.Modules (mean, pseudo_chol). chol can return a diagonal vector when it would only be a diagonal matrix. 
 """ + + # TODO: Rename pseudo_cov to pseudo_chol + mean_actions = nn.Linear(latent_dim, self.action_dim) if self.par_strength == Strength.NONE: if self.cov_strength == Strength.NONE: - pseudo_cov = th.ones(self.action_dim) * log_std_init + pseudo_cov_par = th.ones(self.action_dim) * log_std_init elif self.cov_strength == Strength.SCALAR: - pseudo_cov = th.ones(self.action_dim) * \ + pseudo_cov_par = th.ones(self.action_dim) * \ nn.Parameter(log_std_init, requires_grad=True) elif self.cov_strength == Strength.DIAG: - pseudo_cov = nn.Parameter( + pseudo_cov_par = nn.Parameter( th.ones(self.action_dim) * log_std_init, requires_grad=True) elif self.cov_strength == Strength.FULL: - # Off-axis init? - pseudo_cov = nn.Parameter( + # TODO: This won't work, need to ensure SPD! + # TODO: Off-axis init? + pseudo_cov_par = nn.Parameter( th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True) + pseudo_cov = FakeModule(pseudo_cov_par) elif self.par_strength == self.cov_strength: if self.par_strength == Strength.NONE: - pseudo_cov = th.ones(self.action_dim) + pseudo_cov = FakeModule(th.ones(self.action_dim)) elif self.par_strength == Strength.SCALAR: + # TODO: Does it work like this? Test! NOTE(review): the next statement multiplies a tensor by the nn.Linear module itself rather than by its output - this likely raises TypeError; confirm. 
 std = nn.Linear(latent_dim, 1) pseudo_cov = th.ones(self.action_dim) * std elif self.par_strength == Strength.DIAG: pseudo_cov = nn.Linear(latent_dim, self.action_dim) elif self.par_strength == Strength.FULL: - raise Exception("Don't know how to implement yet...") + pseudo_cov = self._parameterize_full(latent_dim) elif self.par_strength > self.cov_strength: raise Exception( 'The parameterization can not be stronger than the actual covariance.') else: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: - factor = nn.Linear(latent_dim, 1) - par_cov = th.ones(self.action_dim) * \ - nn.Parameter(1, requires_grad=True) - pseudo_cov = par_cov * factor[0] + pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim) + elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: + pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim) elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: raise Exception( 'That does not even make any sense...') else: - raise Exception( - 'Programmer-was-to-lazy-to-implement-this-Exception') + raise Exception("This Exception can't happen (I think)") return mean_actions, pseudo_cov + def _parameterize_full(self, latent_dim): + # TODO: Implement various techniques for full parameterization (forcing SPD) + raise Exception( + 'Programmer-was-to-lazy-to-implement-this-Exception') + + def _parameterize_hybrid_from_diag(self, latent_dim): + # TODO: Implement the hybrid-method for DIAG -> FULL (parameters for pearson-correlation-matrix) + raise Exception( + 'Programmer-was-to-lazy-to-implement-this-Exception') + + def _parameterize_hybrid_from_scalar(self, latent_dim): + factor = nn.Linear(latent_dim, 1) + par_cov = th.ones(self.action_dim) * \ + nn.Parameter(1, requires_grad=True) + pseudo_cov = par_cov * factor[0] + return pseudo_cov + def proba_distribution(self, mean_actions: th.Tensor, pseudo_cov: th.Tensor) -> "UniversalGaussianDistribution": """ Create the 
distribution given its parameters (mean, pseudo_cov) diff --git a/metastable_baselines/misc/fakeModule.py b/metastable_baselines/misc/fakeModule.py new file mode 100644 index 0000000..72f2434 --- /dev/null +++ b/metastable_baselines/misc/fakeModule.py @@ -0,0 +1,20 @@ +import torch as th +from torch import nn + + +class FakeModule(nn.Module): + """ + A torch.nn Module, that drops the input and returns a tensor given at initialization. + Gradients can pass through this Module and affect the given tensor. + """ + # In order to reduce the code required to allow support for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule + + def __init__(self, tensor): + super().__init__() + self.tensor = tensor + + def forward(self, x): + return self.tensor + + def string(self): + return ''