diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index d2f6804..8e54218 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -13,15 +13,12 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution from ..misc.fakeModule import FakeModule -# TODO: Full Cov Parameter +# TODO: Integrate and Test what I currently have before adding more complexity +# TODO: Support Squashed Dists (tanh) # TODO: Contextual Cov -# TODO: - Scalar -# TODO: - Diag -# TODO: - Full # TODO: - Hybrid # TODO: Contextual SDE (Scalar + Diag + Full) # TODO: (SqrtInducedCov (Scalar + Diag + Full)) -# TODO: (Support Squased Dists (tanh)) class Strength(Enum): @@ -52,11 +49,17 @@ class EnforcePositiveType(Enum): LOG = 5 -def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None): +class ProbSquashingType(Enum): + NONE = 0 + TANH = 1 + + +def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): allowedEPTs = allowedEPTs or EnforcePositiveType allowedParStrength = allowedParStrength or Strength allowedCovStrength = allowedCovStrength or Strength allowedPTs = allowedPTs or ParametrizationType + allowedPSTs = allowedPSTs or ProbSquashingType for ps in allowedParStrength: for cs in allowedCovStrength: @@ -89,6 +92,7 @@ class UniversalGaussianDistribution(SB3_Distribution): self.cov_strength = Strength.DIAG self.par_type = ParametrizationType.CHOL self.enforce_positive_type = EnforcePositiveType.LOG + self.prob_squashing_type = ProbSquashingType.TANH self.distribution = None