Working on SDC
This commit is contained in:
parent
4c4b12ee0e
commit
e4440428f8
@ -28,31 +28,39 @@ class Strength(Enum):
|
||||
DIAG = 2
|
||||
FULL = 3
|
||||
|
||||
# def __init__(self, num):
|
||||
# self.num = num
|
||||
|
||||
# @property
|
||||
# def foo(self):
|
||||
# return self.num
|
||||
|
||||
|
||||
class ParametrizationType(Enum):
|
||||
# Currently only Chol is implemented
|
||||
CHOL = 1
|
||||
SPHERICAL_CHOL = 2
|
||||
GIVENS = 3
|
||||
#SPHERICAL_CHOL = 2
|
||||
#GIVENS = 3
|
||||
|
||||
|
||||
class EnforcePositiveType(Enum):
|
||||
SOFTPLUS = 1
|
||||
ABS = 2
|
||||
RELU = 3
|
||||
SELU = 4
|
||||
LOG = 5
|
||||
# TODO: Allow custom params for softplus?
|
||||
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
||||
ABS = (2, th.abs)
|
||||
RELU = (3, nn.ReLU(inplace=False))
|
||||
LOG = (4, th.log)
|
||||
|
||||
def __init__(self, value, func):
|
||||
self.value = value
|
||||
self._func = func
|
||||
|
||||
def apply(self, x):
|
||||
return self._func(x)
|
||||
|
||||
|
||||
class ProbSquashingType(Enum):
|
||||
NONE = 0
|
||||
TANH = 1
|
||||
NONE = (0, nn.Identity())
|
||||
TANH = (1, th.tanh)
|
||||
|
||||
def __init__(self, value, func):
|
||||
self.value = value
|
||||
self._func = func
|
||||
|
||||
def apply(self, x):
|
||||
return self._func(x)
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||
@ -118,10 +126,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:return: We return two nn.Modules (mean, pseudo_chol). chol can return a díagonal vector when it would only be a diagonal matrix.
|
||||
:return: We return two nn.Modules (mean, chol).
|
||||
"""
|
||||
|
||||
# TODO: Rename pseudo_cov to pseudo_chol
|
||||
# TODO: Allow chol to be vector when only diagonal.
|
||||
|
||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||
|
||||
@ -139,33 +147,33 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
# TODO: Off-axis init?
|
||||
pseudo_cov_par = nn.Parameter(
|
||||
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
||||
pseudo_cov = FakeModule(pseudo_cov_par)
|
||||
chol = FakeModule(pseudo_cov_par)
|
||||
elif self.par_strength == self.cov_strength:
|
||||
if self.par_strength == Strength.NONE:
|
||||
pseudo_cov = FakeModule(th.ones(self.action_dim))
|
||||
chol = FakeModule(th.ones(self.action_dim))
|
||||
elif self.par_strength == Strength.SCALAR:
|
||||
# TODO: Does it work like this? Test!
|
||||
std = nn.Linear(latent_dim, 1)
|
||||
pseudo_cov = th.ones(self.action_dim) * std
|
||||
chol = th.ones(self.action_dim) * std
|
||||
elif self.par_strength == Strength.DIAG:
|
||||
pseudo_cov = nn.Linear(latent_dim, self.action_dim)
|
||||
chol = nn.Linear(latent_dim, self.action_dim)
|
||||
elif self.par_strength == Strength.FULL:
|
||||
pseudo_cov = self._parameterize_full(latent_dim)
|
||||
chol = self._parameterize_full(latent_dim)
|
||||
elif self.par_strength > self.cov_strength:
|
||||
raise Exception(
|
||||
'The parameterization can not be stronger than the actual covariance.')
|
||||
else:
|
||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||
pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim)
|
||||
chol = self._parameterize_hybrid_from_scalar(latent_dim)
|
||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||
pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim)
|
||||
chol = self._parameterize_hybrid_from_diag(latent_dim)
|
||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||
raise Exception(
|
||||
'That does not even make any sense...')
|
||||
else:
|
||||
raise Exception("This Exception can't happen (I think)")
|
||||
|
||||
return mean_actions, pseudo_cov
|
||||
return mean_actions, chol
|
||||
|
||||
def _parameterize_full(self, latent_dim):
|
||||
# TODO: Implement various techniques for full parameterization (forcing SPD)
|
||||
@ -177,6 +185,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
raise Exception(
|
||||
'Programmer-was-to-lazy-to-implement-this-Exception')
|
||||
|
||||
def _ensure_positive_func(self, x):
|
||||
return self.enforce_positive_type.apply(x)
|
||||
|
||||
def _ensure_diagonal_positive(self, pseudo_chol):
|
||||
pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2,
|
||||
dim2=-1)).diag_embed() + pseudo_chol.triu(1)
|
||||
|
||||
def _parameterize_hybrid_from_scalar(self, latent_dim):
|
||||
factor = nn.Linear(latent_dim, 1)
|
||||
par_cov = th.ones(self.action_dim) * \
|
||||
|
Loading…
Reference in New Issue
Block a user