Fixing ser/deser bug (cloudpickle cant handle some enums)

This commit is contained in:
Dominik Moritz Roth 2022-07-15 18:45:38 +02:00
parent a86d19053d
commit 2e0f46b0f3

View File

@ -46,31 +46,26 @@ class ParametrizationType(Enum):
class EnforcePositiveType(Enum): class EnforcePositiveType(Enum):
# TODO: Allow custom params for softplus? # This need to be implemented in this ugly fashion,
NONE = (0, nn.Identity()) # because cloudpickle does not like more complex enums
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
ABS = (2, th.abs)
RELU = (3, nn.ReLU(inplace=False))
LOG = (4, th.log)
def __init__(self, value, func): NONE = 0
self.val = value SOFTPLUS = 1
self._func = func ABS = 2
RELU = 3
LOG = 4
def apply(self, x): def apply(self, x):
return self._func(x) # aaaaaa
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
class ProbSquashingType(Enum): class ProbSquashingType(Enum):
NONE = (0, nn.Identity()) NONE = 0
TANH = (1, th.tanh) TANH = 1
def __init__(self, value, func):
self.val = value
self._func = func
def apply(self, x): def apply(self, x):
return self._func(x) return [nn.Identity(), th.tanh][self.value](x)
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
@ -291,13 +286,14 @@ class CholNet(nn.Module):
self.diag_chol = nn.Linear(latent_dim, self.action_dim) self.diag_chol = nn.Linear(latent_dim, self.action_dim)
elif self.par_strength == Strength.FULL: elif self.par_strength == Strength.FULL:
self.params = nn.Linear(latent_dim, self._full_params_len) self.params = nn.Linear(latent_dim, self._full_params_len)
elif self.par_strength > self.cov_strength: elif self.par_strength.value > self.cov_strength.value:
raise Exception( raise Exception(
'The parameterization can not be stronger than the actual covariance.') 'The parameterization can not be stronger than the actual covariance.')
else: else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
self.factor = nn.Linear(latent_dim, 1) self.factor = nn.Linear(latent_dim, 1)
self.param = nn.Parameter(1, requires_grad=True) self.param = nn.Parameter(
th.ones(self.action_dim), requires_grad=True)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
# TODO # TODO
pass pass
@ -334,7 +330,7 @@ class CholNet(nn.Module):
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
factor = self.factor(x) factor = self.factor(x)
diag_chol = self._ensure_positive_func( diag_chol = self._ensure_positive_func(
th.ones(self.action_dim) * self.param * factor[0]) self.param * factor[0])
return diag_chol return diag_chol
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
pass pass