Fixing ser/deser bug (cloudpickle cant handle some enums)
This commit is contained in:
parent
a86d19053d
commit
2e0f46b0f3
@ -46,31 +46,26 @@ class ParametrizationType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class EnforcePositiveType(Enum):
|
class EnforcePositiveType(Enum):
|
||||||
# TODO: Allow custom params for softplus?
|
# This need to be implemented in this ugly fashion,
|
||||||
NONE = (0, nn.Identity())
|
# because cloudpickle does not like more complex enums
|
||||||
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
|
||||||
ABS = (2, th.abs)
|
|
||||||
RELU = (3, nn.ReLU(inplace=False))
|
|
||||||
LOG = (4, th.log)
|
|
||||||
|
|
||||||
def __init__(self, value, func):
|
NONE = 0
|
||||||
self.val = value
|
SOFTPLUS = 1
|
||||||
self._func = func
|
ABS = 2
|
||||||
|
RELU = 3
|
||||||
|
LOG = 4
|
||||||
|
|
||||||
def apply(self, x):
|
def apply(self, x):
|
||||||
return self._func(x)
|
# aaaaaa
|
||||||
|
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
|
||||||
|
|
||||||
|
|
||||||
class ProbSquashingType(Enum):
|
class ProbSquashingType(Enum):
|
||||||
NONE = (0, nn.Identity())
|
NONE = 0
|
||||||
TANH = (1, th.tanh)
|
TANH = 1
|
||||||
|
|
||||||
def __init__(self, value, func):
|
|
||||||
self.val = value
|
|
||||||
self._func = func
|
|
||||||
|
|
||||||
def apply(self, x):
|
def apply(self, x):
|
||||||
return self._func(x)
|
return [nn.Identity(), th.tanh][self.value](x)
|
||||||
|
|
||||||
|
|
||||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||||
@ -291,13 +286,14 @@ class CholNet(nn.Module):
|
|||||||
self.diag_chol = nn.Linear(latent_dim, self.action_dim)
|
self.diag_chol = nn.Linear(latent_dim, self.action_dim)
|
||||||
elif self.par_strength == Strength.FULL:
|
elif self.par_strength == Strength.FULL:
|
||||||
self.params = nn.Linear(latent_dim, self._full_params_len)
|
self.params = nn.Linear(latent_dim, self._full_params_len)
|
||||||
elif self.par_strength > self.cov_strength:
|
elif self.par_strength.value > self.cov_strength.value:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'The parameterization can not be stronger than the actual covariance.')
|
'The parameterization can not be stronger than the actual covariance.')
|
||||||
else:
|
else:
|
||||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||||
self.factor = nn.Linear(latent_dim, 1)
|
self.factor = nn.Linear(latent_dim, 1)
|
||||||
self.param = nn.Parameter(1, requires_grad=True)
|
self.param = nn.Parameter(
|
||||||
|
th.ones(self.action_dim), requires_grad=True)
|
||||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
# TODO
|
# TODO
|
||||||
pass
|
pass
|
||||||
@ -334,7 +330,7 @@ class CholNet(nn.Module):
|
|||||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||||
factor = self.factor(x)
|
factor = self.factor(x)
|
||||||
diag_chol = self._ensure_positive_func(
|
diag_chol = self._ensure_positive_func(
|
||||||
th.ones(self.action_dim) * self.param * factor[0])
|
self.param * factor[0])
|
||||||
return diag_chol
|
return diag_chol
|
||||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user