Wrote a little helper-function to generate all allowed combinations of
cov-parameterizations
This commit is contained in:
parent
e09950b30c
commit
249754ee89
@ -30,25 +30,50 @@ class Strength(Enum):
|
|||||||
DIAG = 2
|
DIAG = 2
|
||||||
FULL = 3
|
FULL = 3
|
||||||
|
|
||||||
def __init__(self, num):
|
# def __init__(self, num):
|
||||||
self.num = num
|
# self.num = num
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def foo(self):
|
# def foo(self):
|
||||||
return self.num
|
# return self.num
|
||||||
|
|
||||||
|
|
||||||
class ParametrizationType(Enum):
|
class ParametrizationType(Enum):
|
||||||
CHOL = 1
|
CHOL = 1
|
||||||
ARCHAKOVA = 2
|
SPHERICAL_CHOL = 2
|
||||||
|
GIVENS = 3
|
||||||
|
|
||||||
|
|
||||||
class EnforcePositiveType(Enum):
|
class EnforcePositiveType(Enum):
|
||||||
LOG = 1
|
SOFTPLUS = 1
|
||||||
RELU = 2
|
ABS = 2
|
||||||
SELU = 3
|
RELU = 3
|
||||||
ABS = 4
|
SELU = 4
|
||||||
SQ = 5
|
LOG = 5
|
||||||
|
|
||||||
|
|
||||||
|
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None):
|
||||||
|
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||||
|
allowedParStrength = allowedParStrength or Strength
|
||||||
|
allowedCovStrength = allowedCovStrength or Strength
|
||||||
|
allowedPTs = allowedPTs or ParametrizationType
|
||||||
|
|
||||||
|
for ps in allowedParStrength:
|
||||||
|
for cs in allowedCovStrength:
|
||||||
|
if ps.value > cs.value:
|
||||||
|
continue
|
||||||
|
if ps == Strength.SCALAR and cs == Strength.FULL:
|
||||||
|
# TODO: Maybe allow?
|
||||||
|
continue
|
||||||
|
if ps == Strength.NONE:
|
||||||
|
yield (ps, cs, None, None)
|
||||||
|
else:
|
||||||
|
for ept in allowedEPTs:
|
||||||
|
if cs == Strength.FULL:
|
||||||
|
for pt in allowedPTs:
|
||||||
|
yield (ps, cs, ept, pt)
|
||||||
|
else:
|
||||||
|
yield (ps, cs, ept, None)
|
||||||
|
|
||||||
|
|
||||||
class UniversalGaussianDistribution(SB3_Distribution):
|
class UniversalGaussianDistribution(SB3_Distribution):
|
||||||
|
Loading…
Reference in New Issue
Block a user