Working on SDC

This commit is contained in:
Dominik Moritz Roth 2022-07-11 11:55:23 +02:00
parent 4c4b12ee0e
commit e4440428f8

View File

@ -28,31 +28,39 @@ class Strength(Enum):
DIAG = 2
FULL = 3
# def __init__(self, num):
# self.num = num
# @property
# def foo(self):
# return self.num
class ParametrizationType(Enum):
    """Available parametrizations for the covariance factor.

    Only ``CHOL`` is exposed as a member. The other candidates stay
    commented out so they cannot be selected before they are implemented.
    """

    # Currently only Chol is implemented
    CHOL = 1
    # SPHERICAL_CHOL = 2
    # GIVENS = 3
class EnforcePositiveType(Enum):
    """Selectable element-wise functions, each paired with an integer id.

    Every member is declared as ``(value, func)``. ``value`` becomes the
    member's ``.value`` (so lookup by integer works) and ``func`` is the
    callable that :meth:`apply` forwards to.
    """

    # TODO: Allow custom params for softplus?
    SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
    ABS = (2, th.abs)
    RELU = (3, nn.ReLU(inplace=False))
    # NOTE(review): th.log is only defined for positive inputs and is negative
    # on (0, 1) -- confirm this is the intended mapping for this member.
    LOG = (4, th.log)

    def __new__(cls, value, func):
        # ``Enum.value`` is read-only, so the integer id has to be stored via
        # ``_value_`` inside ``__new__``; assigning ``self.value`` in
        # ``__init__`` raises AttributeError while the class is being built.
        member = object.__new__(cls)
        member._value_ = value
        member._func = func
        return member

    def apply(self, x):
        """Return the result of calling this member's function on ``x``."""
        return self._func(x)
class ProbSquashingType(Enum):
    """Selectable squashing functions, each paired with an integer id.

    Every member is declared as ``(value, func)``. ``value`` becomes the
    member's ``.value`` and ``func`` is the callable used by :meth:`apply`.
    ``NONE`` maps to an identity module, ``TANH`` to ``th.tanh``.
    """

    NONE = (0, nn.Identity())
    TANH = (1, th.tanh)

    def __new__(cls, value, func):
        # ``Enum.value`` is read-only; store the integer id through
        # ``_value_`` here instead of assigning ``self.value`` in
        # ``__init__``, which raises AttributeError at class creation.
        member = object.__new__(cls)
        member._value_ = value
        member._func = func
        return member

    def apply(self, x):
        """Return the result of calling this member's function on ``x``."""
        return self._func(x)
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
@ -118,10 +126,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
:param log_std_init: Initial value for the log standard deviation
:return: We return two nn.Modules (mean, pseudo_chol). chol can return a diagonal vector when it would only be a diagonal matrix.
:return: We return two nn.Modules (mean, chol).
"""
# TODO: Rename pseudo_cov to pseudo_chol
# TODO: Allow chol to be vector when only diagonal.
mean_actions = nn.Linear(latent_dim, self.action_dim)
@ -139,33 +147,33 @@ class UniversalGaussianDistribution(SB3_Distribution):
# TODO: Off-axis init?
pseudo_cov_par = nn.Parameter(
th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True)
pseudo_cov = FakeModule(pseudo_cov_par)
chol = FakeModule(pseudo_cov_par)
elif self.par_strength == self.cov_strength:
if self.par_strength == Strength.NONE:
pseudo_cov = FakeModule(th.ones(self.action_dim))
chol = FakeModule(th.ones(self.action_dim))
elif self.par_strength == Strength.SCALAR:
# TODO: Does it work like this? Test!
std = nn.Linear(latent_dim, 1)
pseudo_cov = th.ones(self.action_dim) * std
chol = th.ones(self.action_dim) * std
elif self.par_strength == Strength.DIAG:
pseudo_cov = nn.Linear(latent_dim, self.action_dim)
chol = nn.Linear(latent_dim, self.action_dim)
elif self.par_strength == Strength.FULL:
pseudo_cov = self._parameterize_full(latent_dim)
chol = self._parameterize_full(latent_dim)
elif self.par_strength > self.cov_strength:
raise Exception(
'The parameterization can not be stronger than the actual covariance.')
else:
if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG:
pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim)
chol = self._parameterize_hybrid_from_scalar(latent_dim)
elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim)
chol = self._parameterize_hybrid_from_diag(latent_dim)
elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
raise Exception(
'That does not even make any sense...')
else:
raise Exception("This Exception can't happen (I think)")
return mean_actions, pseudo_cov
return mean_actions, chol
def _parameterize_full(self, latent_dim):
# TODO: Implement various techniques for full parameterization (forcing SPD)
@ -177,6 +185,13 @@ class UniversalGaussianDistribution(SB3_Distribution):
raise Exception(
'Programmer-was-to-lazy-to-implement-this-Exception')
def _ensure_positive_func(self, x):
    """Pass ``x`` through ``self.enforce_positive_type.apply`` and return the result."""
    # presumably an EnforcePositiveType member -- verify where the attribute is set
    positivity = self.enforce_positive_type
    return positivity.apply(x)
def _ensure_diagonal_positive(self, pseudo_chol):
    """Return ``pseudo_chol`` with its main diagonal mapped through the positivity function.

    The diagonal over the last two dimensions is extracted, passed through
    ``self._ensure_positive_func`` and re-embedded as a diagonal matrix.
    The strictly lower and strictly upper triangles are carried over
    unchanged.

    :param pseudo_chol: tensor whose last two dimensions form the matrix
        (assumed square so the three parts add up -- TODO confirm with callers)
    :return: a new tensor; ``pseudo_chol`` itself is not modified
    """
    diag = pseudo_chol.diagonal(dim1=-2, dim2=-1)
    positive_diag = self._ensure_positive_func(diag).diag_embed()
    # The original computed this sum but never returned it, so callers
    # always received None.
    return pseudo_chol.tril(-1) + positive_diag + pseudo_chol.triu(1)
def _parameterize_hybrid_from_scalar(self, latent_dim):
factor = nn.Linear(latent_dim, 1)
par_cov = th.ones(self.action_dim) * \