From 249754ee890be8d03d04c75af399628c2195c6a5 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 9 Jul 2022 14:03:56 +0200 Subject: [PATCH] Wrote a little helper-function to generate all allowed combinations of cov-parameterizations --- .../distributions/distributions.py | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 05860c0..d2f6804 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -30,25 +30,50 @@ class Strength(Enum): DIAG = 2 FULL = 3 - def __init__(self, num): - self.num = num + # def __init__(self, num): + # self.num = num - @property - def foo(self): - return self.num + # @property + # def foo(self): + # return self.num class ParametrizationType(Enum): CHOL = 1 - ARCHAKOVA = 2 + SPHERICAL_CHOL = 2 + GIVENS = 3 class EnforcePositiveType(Enum): - LOG = 1 - RELU = 2 - SELU = 3 - ABS = 4 - SQ = 5 + SOFTPLUS = 1 + ABS = 2 + RELU = 3 + SELU = 4 + LOG = 5 + + +def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None): + allowedEPTs = allowedEPTs or EnforcePositiveType + allowedParStrength = allowedParStrength or Strength + allowedCovStrength = allowedCovStrength or Strength + allowedPTs = allowedPTs or ParametrizationType + + for ps in allowedParStrength: + for cs in allowedCovStrength: + if ps.value > cs.value: + continue + if ps == Strength.SCALAR and cs == Strength.FULL: + # TODO: Maybe allow? + continue + if ps == Strength.NONE: + yield (ps, cs, None, None) + else: + for ept in allowedEPTs: + if cs == Strength.FULL: + for pt in allowedPTs: + yield (ps, cs, ept, pt) + else: + yield (ps, cs, ept, None) class UniversalGaussianDistribution(SB3_Distribution):