Making UniversalGaussianDistribution ready for tanh-squashing-support
This commit is contained in:
parent
249754ee89
commit
c08ea1cb91
@ -13,15 +13,12 @@ from stable_baselines3.common.distributions import DiagGaussianDistribution
|
||||
|
||||
from ..misc.fakeModule import FakeModule
|
||||
|
||||
# TODO: Full Cov Parameter
|
||||
# TODO: Integrate and Test what I currently have before adding more complexity
|
||||
# TODO: Support Squashed Dists (tanh)
|
||||
# TODO: Contextual Cov
|
||||
# TODO: - Scalar
|
||||
# TODO: - Diag
|
||||
# TODO: - Full
|
||||
# TODO: - Hybrid
|
||||
# TODO: Contextual SDE (Scalar + Diag + Full)
|
||||
# TODO: (SqrtInducedCov (Scalar + Diag + Full))
|
||||
# TODO: (Support Squased Dists (tanh))
|
||||
|
||||
|
||||
class Strength(Enum):
|
||||
@ -52,11 +49,17 @@ class EnforcePositiveType(Enum):
|
||||
LOG = 5
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None):
|
||||
class ProbSquashingType(Enum):
|
||||
NONE = 0
|
||||
TANH = 1
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||
allowedParStrength = allowedParStrength or Strength
|
||||
allowedCovStrength = allowedCovStrength or Strength
|
||||
allowedPTs = allowedPTs or ParametrizationType
|
||||
allowedPSTs = allowedPSTs or ProbSquashingType
|
||||
|
||||
for ps in allowedParStrength:
|
||||
for cs in allowedCovStrength:
|
||||
@ -89,6 +92,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
self.cov_strength = Strength.DIAG
|
||||
self.par_type = ParametrizationType.CHOL
|
||||
self.enforce_positive_type = EnforcePositiveType.LOG
|
||||
self.prob_squashing_type = ProbSquashingType.TANH
|
||||
|
||||
self.distribution = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user