Making UniversalGaussianDistribution ready for tanh-squashing-support

This commit is contained in:
Dominik Moritz Roth 2022-07-09 14:33:07 +02:00
parent 249754ee89
commit c08ea1cb91

View File

@ -13,15 +13,12 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
from ..misc.fakeModule import FakeModule
# TODO: Full Cov Parameter
# TODO: Integrate and Test what I currently have before adding more complexity
# TODO: Support Squashed Dists (tanh)
# TODO: Contextual Cov
# TODO: - Scalar
# TODO: - Diag
# TODO: - Full
# TODO: - Hybrid
# TODO: Contextual SDE (Scalar + Diag + Full)
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
# TODO: (Support Squased Dists (tanh))
class Strength(Enum):
@ -52,11 +49,17 @@ class EnforcePositiveType(Enum):
LOG = 5
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None):
class ProbSquashingType(Enum):
NONE = 0
TANH = 1
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
allowedEPTs = allowedEPTs or EnforcePositiveType
allowedParStrength = allowedParStrength or Strength
allowedCovStrength = allowedCovStrength or Strength
allowedPTs = allowedPTs or ParametrizationType
allowedPSTs = allowedPSTs or ProbSquashingType
for ps in allowedParStrength:
for cs in allowedCovStrength:
@ -89,6 +92,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.cov_strength = Strength.DIAG
self.par_type = ParametrizationType.CHOL
self.enforce_positive_type = EnforcePositiveType.LOG
self.prob_squashing_type = ProbSquashingType.TANH
self.distribution = None