diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 421926f..ed282d6 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -46,31 +46,26 @@ class ParametrizationType(Enum): class EnforcePositiveType(Enum): - # TODO: Allow custom params for softplus? - NONE = (0, nn.Identity()) - SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20)) - ABS = (2, th.abs) - RELU = (3, nn.ReLU(inplace=False)) - LOG = (4, th.log) + # This needs to be implemented in this ugly fashion, + # because cloudpickle does not like more complex enums - def __init__(self, value, func): - self.val = value - self._func = func + NONE = 0 + SOFTPLUS = 1 + ABS = 2 + RELU = 3 + LOG = 4 def apply(self, x): - return self._func(x) + # Look up the callable by enum value; the list order must match the member values above + return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x) class ProbSquashingType(Enum): - NONE = (0, nn.Identity()) - TANH = (1, th.tanh) - - def __init__(self, value, func): - self.val = value - self._func = func + NONE = 0 + TANH = 1 def apply(self, x): - return self._func(x) + return [nn.Identity(), th.tanh][self.value](x) def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): @@ -291,13 +286,14 @@ class CholNet(nn.Module): self.diag_chol = nn.Linear(latent_dim, self.action_dim) elif self.par_strength == Strength.FULL: self.params = nn.Linear(latent_dim, self._full_params_len) - elif self.par_strength > self.cov_strength: + elif self.par_strength.value > self.cov_strength.value: raise Exception( 'The parameterization can not be stronger than the actual covariance.') else: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: self.factor = nn.Linear(latent_dim, 1) - self.param = nn.Parameter(1, requires_grad=True) + self.param = nn.Parameter( + th.ones(self.action_dim), 
requires_grad=True) elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: # TODO pass @@ -334,7 +330,7 @@ class CholNet(nn.Module): if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: factor = self.factor(x) diag_chol = self._ensure_positive_func( - th.ones(self.action_dim) * self.param * factor[0]) + self.param * factor[0]) return diag_chol elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: pass