Fixing ser/deser bug (cloudpickle cant handle some enums)
This commit is contained in:
parent
a86d19053d
commit
2e0f46b0f3
@ -46,31 +46,26 @@ class ParametrizationType(Enum):
|
||||
|
||||
|
||||
class EnforcePositiveType(Enum):
|
||||
# TODO: Allow custom params for softplus?
|
||||
NONE = (0, nn.Identity())
|
||||
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
||||
ABS = (2, th.abs)
|
||||
RELU = (3, nn.ReLU(inplace=False))
|
||||
LOG = (4, th.log)
|
||||
# This need to be implemented in this ugly fashion,
|
||||
# because cloudpickle does not like more complex enums
|
||||
|
||||
def __init__(self, value, func):
|
||||
self.val = value
|
||||
self._func = func
|
||||
NONE = 0
|
||||
SOFTPLUS = 1
|
||||
ABS = 2
|
||||
RELU = 3
|
||||
LOG = 4
|
||||
|
||||
def apply(self, x):
|
||||
return self._func(x)
|
||||
# aaaaaa
|
||||
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
|
||||
|
||||
|
||||
class ProbSquashingType(Enum):
|
||||
NONE = (0, nn.Identity())
|
||||
TANH = (1, th.tanh)
|
||||
|
||||
def __init__(self, value, func):
|
||||
self.val = value
|
||||
self._func = func
|
||||
NONE = 0
|
||||
TANH = 1
|
||||
|
||||
def apply(self, x):
|
||||
return self._func(x)
|
||||
return [nn.Identity(), th.tanh][self.value](x)
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||
@ -291,13 +286,14 @@ class CholNet(nn.Module):
|
||||
self.diag_chol = nn.Linear(latent_dim, self.action_dim)
|
||||
elif self.par_strength == Strength.FULL:
|
||||
self.params = nn.Linear(latent_dim, self._full_params_len)
|
||||
elif self.par_strength > self.cov_strength:
|
||||
elif self.par_strength.value > self.cov_strength.value:
|
||||
raise Exception(
|
||||
'The parameterization can not be stronger than the actual covariance.')
|
||||
else:
|
||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||
self.factor = nn.Linear(latent_dim, 1)
|
||||
self.param = nn.Parameter(1, requires_grad=True)
|
||||
self.param = nn.Parameter(
|
||||
th.ones(self.action_dim), requires_grad=True)
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
# TODO
|
||||
pass
|
||||
@ -334,7 +330,7 @@ class CholNet(nn.Module):
|
||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||
factor = self.factor(x)
|
||||
diag_chol = self._ensure_positive_func(
|
||||
th.ones(self.action_dim) * self.param * factor[0])
|
||||
self.param * factor[0])
|
||||
return diag_chol
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user