Working on SDC
This commit is contained in:
parent
4c4b12ee0e
commit
e4440428f8
@ -28,31 +28,39 @@ class Strength(Enum):
|
|||||||
DIAG = 2
|
DIAG = 2
|
||||||
FULL = 3
|
FULL = 3
|
||||||
|
|
||||||
# def __init__(self, num):
|
|
||||||
# self.num = num
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def foo(self):
|
|
||||||
# return self.num
|
|
||||||
|
|
||||||
|
|
||||||
class ParametrizationType(Enum):
|
class ParametrizationType(Enum):
|
||||||
|
# Currently only Chol is implemented
|
||||||
CHOL = 1
|
CHOL = 1
|
||||||
SPHERICAL_CHOL = 2
|
#SPHERICAL_CHOL = 2
|
||||||
GIVENS = 3
|
#GIVENS = 3
|
||||||
|
|
||||||
|
|
||||||
class EnforcePositiveType(Enum):
|
class EnforcePositiveType(Enum):
|
||||||
SOFTPLUS = 1
|
# TODO: Allow custom params for softplus?
|
||||||
ABS = 2
|
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
||||||
RELU = 3
|
ABS = (2, th.abs)
|
||||||
SELU = 4
|
RELU = (3, nn.ReLU(inplace=False))
|
||||||
LOG = 5
|
LOG = (4, th.log)
|
||||||
|
|
||||||
|
def __init__(self, value, func):
|
||||||
|
self.value = value
|
||||||
|
self._func = func
|
||||||
|
|
||||||
|
def apply(self, x):
|
||||||
|
return self._func(x)
|
||||||
|
|
||||||
|
|
||||||
class ProbSquashingType(Enum):
|
class ProbSquashingType(Enum):
|
||||||
NONE = 0
|
NONE = (0, nn.Identity())
|
||||||
TANH = 1
|
TANH = (1, th.tanh)
|
||||||
|
|
||||||
|
def __init__(self, value, func):
|
||||||
|
self.value = value
|
||||||
|
self._func = func
|
||||||
|
|
||||||
|
def apply(self, x):
|
||||||
|
return self._func(x)
|
||||||
|
|
||||||
|
|
||||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||||
@ -118,10 +126,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
|
|
||||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||||
:param log_std_init: Initial value for the log standard deviation
|
:param log_std_init: Initial value for the log standard deviation
|
||||||
:return: We return two nn.Modules (mean, pseudo_chol). chol can return a díagonal vector when it would only be a diagonal matrix.
|
:return: We return two nn.Modules (mean, chol).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Rename pseudo_cov to pseudo_chol
|
# TODO: Allow chol to be vector when only diagonal.
|
||||||
|
|
||||||
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
mean_actions = nn.Linear(latent_dim, self.action_dim)
|
||||||
|
|
||||||
@ -139,33 +147,33 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
# TODO: Off-axis init?
|
# TODO: Off-axis init?
|
||||||
pseudo_cov_par = nn.Parameter(
|
pseudo_cov_par = nn.Parameter(
|
||||||
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
|
||||||
pseudo_cov = FakeModule(pseudo_cov_par)
|
chol = FakeModule(pseudo_cov_par)
|
||||||
elif self.par_strength == self.cov_strength:
|
elif self.par_strength == self.cov_strength:
|
||||||
if self.par_strength == Strength.NONE:
|
if self.par_strength == Strength.NONE:
|
||||||
pseudo_cov = FakeModule(th.ones(self.action_dim))
|
chol = FakeModule(th.ones(self.action_dim))
|
||||||
elif self.par_strength == Strength.SCALAR:
|
elif self.par_strength == Strength.SCALAR:
|
||||||
# TODO: Does it work like this? Test!
|
# TODO: Does it work like this? Test!
|
||||||
std = nn.Linear(latent_dim, 1)
|
std = nn.Linear(latent_dim, 1)
|
||||||
pseudo_cov = th.ones(self.action_dim) * std
|
chol = th.ones(self.action_dim) * std
|
||||||
elif self.par_strength == Strength.DIAG:
|
elif self.par_strength == Strength.DIAG:
|
||||||
pseudo_cov = nn.Linear(latent_dim, self.action_dim)
|
chol = nn.Linear(latent_dim, self.action_dim)
|
||||||
elif self.par_strength == Strength.FULL:
|
elif self.par_strength == Strength.FULL:
|
||||||
pseudo_cov = self._parameterize_full(latent_dim)
|
chol = self._parameterize_full(latent_dim)
|
||||||
elif self.par_strength > self.cov_strength:
|
elif self.par_strength > self.cov_strength:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'The parameterization can not be stronger than the actual covariance.')
|
'The parameterization can not be stronger than the actual covariance.')
|
||||||
else:
|
else:
|
||||||
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
|
||||||
pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim)
|
chol = self._parameterize_hybrid_from_scalar(latent_dim)
|
||||||
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
|
||||||
pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim)
|
chol = self._parameterize_hybrid_from_diag(latent_dim)
|
||||||
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'That does not even make any sense...')
|
'That does not even make any sense...')
|
||||||
else:
|
else:
|
||||||
raise Exception("This Exception can't happen (I think)")
|
raise Exception("This Exception can't happen (I think)")
|
||||||
|
|
||||||
return mean_actions, pseudo_cov
|
return mean_actions, chol
|
||||||
|
|
||||||
def _parameterize_full(self, latent_dim):
|
def _parameterize_full(self, latent_dim):
|
||||||
# TODO: Implement various techniques for full parameterization (forcing SPD)
|
# TODO: Implement various techniques for full parameterization (forcing SPD)
|
||||||
@ -177,6 +185,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
'Programmer-was-to-lazy-to-implement-this-Exception')
|
'Programmer-was-to-lazy-to-implement-this-Exception')
|
||||||
|
|
||||||
|
def _ensure_positive_func(self, x):
|
||||||
|
return self.enforce_positive_type.apply(x)
|
||||||
|
|
||||||
|
def _ensure_diagonal_positive(self, pseudo_chol):
|
||||||
|
pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2,
|
||||||
|
dim2=-1)).diag_embed() + pseudo_chol.triu(1)
|
||||||
|
|
||||||
def _parameterize_hybrid_from_scalar(self, latent_dim):
|
def _parameterize_hybrid_from_scalar(self, latent_dim):
|
||||||
factor = nn.Linear(latent_dim, 1)
|
factor = nn.Linear(latent_dim, 1)
|
||||||
par_cov = th.ones(self.action_dim) * \
|
par_cov = th.ones(self.action_dim) * \
|
||||||
|
Loading…
Reference in New Issue
Block a user