Wrote a little helper-function to generate all allowed combinations of

cov-parameterizations
This commit is contained in:
Dominik Moritz Roth 2022-07-09 14:03:56 +02:00
parent e09950b30c
commit 249754ee89

View File

@ -30,25 +30,50 @@ class Strength(Enum):
DIAG = 2 DIAG = 2
FULL = 3 FULL = 3
def __init__(self, num): # def __init__(self, num):
self.num = num # self.num = num
@property # @property
def foo(self): # def foo(self):
return self.num # return self.num
class ParametrizationType(Enum): class ParametrizationType(Enum):
CHOL = 1 CHOL = 1
ARCHAKOVA = 2 SPHERICAL_CHOL = 2
GIVENS = 3
class EnforcePositiveType(Enum): class EnforcePositiveType(Enum):
LOG = 1 SOFTPLUS = 1
RELU = 2 ABS = 2
SELU = 3 RELU = 3
ABS = 4 SELU = 4
SQ = 5 LOG = 5
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None):
allowedEPTs = allowedEPTs or EnforcePositiveType
allowedParStrength = allowedParStrength or Strength
allowedCovStrength = allowedCovStrength or Strength
allowedPTs = allowedPTs or ParametrizationType
for ps in allowedParStrength:
for cs in allowedCovStrength:
if ps.value > cs.value:
continue
if ps == Strength.SCALAR and cs == Strength.FULL:
# TODO: Maybe allow?
continue
if ps == Strength.NONE:
yield (ps, cs, None, None)
else:
for ept in allowedEPTs:
if cs == Strength.FULL:
for pt in allowedPTs:
yield (ps, cs, ept, pt)
else:
yield (ps, cs, ept, None)
class UniversalGaussianDistribution(SB3_Distribution): class UniversalGaussianDistribution(SB3_Distribution):