From e4440428f8340d1c89193df3bb08c6bf6d21c3ca Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 11 Jul 2022 11:55:23 +0200 Subject: [PATCH] Working on SDC --- .../distributions/distributions.py | 67 ++++++++++++------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 45f7eec..1494424 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -28,31 +28,39 @@ class Strength(Enum): DIAG = 2 FULL = 3 - # def __init__(self, num): - # self.num = num - - # @property - # def foo(self): - # return self.num - class ParametrizationType(Enum): + # Currently only Chol is implemented CHOL = 1 - SPHERICAL_CHOL = 2 - GIVENS = 3 + #SPHERICAL_CHOL = 2 + #GIVENS = 3 class EnforcePositiveType(Enum): - SOFTPLUS = 1 - ABS = 2 - RELU = 3 - SELU = 4 - LOG = 5 + # TODO: Allow custom params for softplus? + SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20)) + ABS = (2, th.abs) + RELU = (3, nn.ReLU(inplace=False)) + LOG = (4, th.log) + + def __init__(self, value, func): + self.value = value + self._func = func + + def apply(self, x): + return self._func(x) class ProbSquashingType(Enum): - NONE = 0 - TANH = 1 + NONE = (0, nn.Identity()) + TANH = (1, th.tanh) + + def __init__(self, value, func): + self.value = value + self._func = func + + def apply(self, x): + return self._func(x) def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): @@ -118,10 +126,10 @@ class UniversalGaussianDistribution(SB3_Distribution): :param latent_dim: Dimension of the last layer of the policy (before the action layer) :param log_std_init: Initial value for the log standard deviation - :return: We return two nn.Modules (mean, pseudo_chol). chol can return a díagonal vector when it would only be a diagonal matrix. + :return: We return two nn.Modules (mean, chol). """ - # TODO: Rename pseudo_cov to pseudo_chol + # TODO: Allow chol to be vector when only diagonal. mean_actions = nn.Linear(latent_dim, self.action_dim) @@ -139,33 +147,33 @@ class UniversalGaussianDistribution(SB3_Distribution): # TODO: Off-axis init? pseudo_cov_par = nn.Parameter( th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True) - pseudo_cov = FakeModule(pseudo_cov_par) + chol = FakeModule(pseudo_cov_par) elif self.par_strength == self.cov_strength: if self.par_strength == Strength.NONE: - pseudo_cov = FakeModule(th.ones(self.action_dim)) + chol = FakeModule(th.ones(self.action_dim)) elif self.par_strength == Strength.SCALAR: # TODO: Does it work like this? Test! std = nn.Linear(latent_dim, 1) - pseudo_cov = th.ones(self.action_dim) * std + chol = th.ones(self.action_dim) * std elif self.par_strength == Strength.DIAG: - pseudo_cov = nn.Linear(latent_dim, self.action_dim) + chol = nn.Linear(latent_dim, self.action_dim) elif self.par_strength == Strength.FULL: - pseudo_cov = self._parameterize_full(latent_dim) + chol = self._parameterize_full(latent_dim) elif self.par_strength > self.cov_strength: raise Exception( 'The parameterization can not be stronger than the actual covariance.') else: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: - pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim) + chol = self._parameterize_hybrid_from_scalar(latent_dim) elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: - pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim) + chol = self._parameterize_hybrid_from_diag(latent_dim) elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: raise Exception( 'That does not even make any sense...') else: raise Exception("This Exception can't happen (I think)") - return mean_actions, pseudo_cov + return mean_actions, chol def _parameterize_full(self, latent_dim): # TODO: Implement various techniques for full parameterization (forcing SPD) @@ -177,6 +185,13 @@ class UniversalGaussianDistribution(SB3_Distribution): raise Exception( 'Programmer-was-to-lazy-to-implement-this-Exception') + def _ensure_positive_func(self, x): + return self.enforce_positive_type.apply(x) + + def _ensure_diagonal_positive(self, pseudo_chol): + pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2, + dim2=-1)).diag_embed() + pseudo_chol.triu(1) + def _parameterize_hybrid_from_scalar(self, latent_dim): factor = nn.Linear(latent_dim, 1) par_cov = th.ones(self.action_dim) * \