Wrote a little helper-function to generate all allowed combinations of
cov-parameterizations
This commit is contained in:
parent
e09950b30c
commit
249754ee89
@ -30,25 +30,50 @@ class Strength(Enum):
|
||||
DIAG = 2
|
||||
FULL = 3
|
||||
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
# def __init__(self, num):
|
||||
# self.num = num
|
||||
|
||||
@property
|
||||
def foo(self):
|
||||
return self.num
|
||||
# @property
|
||||
# def foo(self):
|
||||
# return self.num
|
||||
|
||||
|
||||
class ParametrizationType(Enum):
|
||||
CHOL = 1
|
||||
ARCHAKOVA = 2
|
||||
SPHERICAL_CHOL = 2
|
||||
GIVENS = 3
|
||||
|
||||
|
||||
class EnforcePositiveType(Enum):
|
||||
LOG = 1
|
||||
RELU = 2
|
||||
SELU = 3
|
||||
ABS = 4
|
||||
SQ = 5
|
||||
SOFTPLUS = 1
|
||||
ABS = 2
|
||||
RELU = 3
|
||||
SELU = 4
|
||||
LOG = 5
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None):
|
||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||
allowedParStrength = allowedParStrength or Strength
|
||||
allowedCovStrength = allowedCovStrength or Strength
|
||||
allowedPTs = allowedPTs or ParametrizationType
|
||||
|
||||
for ps in allowedParStrength:
|
||||
for cs in allowedCovStrength:
|
||||
if ps.value > cs.value:
|
||||
continue
|
||||
if ps == Strength.SCALAR and cs == Strength.FULL:
|
||||
# TODO: Maybe allow?
|
||||
continue
|
||||
if ps == Strength.NONE:
|
||||
yield (ps, cs, None, None)
|
||||
else:
|
||||
for ept in allowedEPTs:
|
||||
if cs == Strength.FULL:
|
||||
for pt in allowedPTs:
|
||||
yield (ps, cs, ept, pt)
|
||||
else:
|
||||
yield (ps, cs, ept, None)
|
||||
|
||||
|
||||
class UniversalGaussianDistribution(SB3_Distribution):
|
||||
|
Loading…
Reference in New Issue
Block a user