Fixing ser/deser bug (cloudpickle cant handle some enums)

This commit is contained in:
Dominik Moritz Roth 2022-07-15 18:45:38 +02:00
parent a86d19053d
commit 2e0f46b0f3

View File

@ -46,31 +46,26 @@ class ParametrizationType(Enum):
class EnforcePositiveType(Enum):
# TODO: Allow custom params for softplus?
NONE = (0, nn.Identity())
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
ABS = (2, th.abs)
RELU = (3, nn.ReLU(inplace=False))
LOG = (4, th.log)
# This need to be implemented in this ugly fashion,
# because cloudpickle does not like more complex enums
def __init__(self, value, func):
self.val = value
self._func = func
NONE = 0
SOFTPLUS = 1
ABS = 2
RELU = 3
LOG = 4
def apply(self, x):
return self._func(x)
# aaaaaa
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
class ProbSquashingType(Enum):
NONE = (0, nn.Identity())
TANH = (1, th.tanh)
def __init__(self, value, func):
self.val = value
self._func = func
NONE = 0
TANH = 1
def apply(self, x):
return self._func(x)
return [nn.Identity(), th.tanh][self.value](x)
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
@ -291,13 +286,14 @@ class CholNet(nn.Module):
self.diag_chol = nn.Linear(latent_dim, self.action_dim)
elif self.par_strength == Strength.FULL:
self.params = nn.Linear(latent_dim, self._full_params_len)
elif self.par_strength > self.cov_strength:
elif self.par_strength.value > self.cov_strength.value:
raise Exception(
'The parameterization can not be stronger than the actual covariance.')
else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
self.factor = nn.Linear(latent_dim, 1)
self.param = nn.Parameter(1, requires_grad=True)
self.param = nn.Parameter(
th.ones(self.action_dim), requires_grad=True)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
# TODO
pass
@ -334,7 +330,7 @@ class CholNet(nn.Module):
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
factor = self.factor(x)
diag_chol = self._ensure_positive_func(
th.ones(self.action_dim) * self.param * factor[0])
self.param * factor[0])
return diag_chol
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
pass