diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 93fccac..9752560 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -72,6 +72,13 @@ class ProbSquashingType(Enum): return [nn.Identity(), TanhBijector.inverse][self.value](x) +def cast_to_enum(inp, Class): + if isinstance(inp, Enum): + return inp + else: + return Class[inp] + + def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): allowedEPTs = allowedEPTs or EnforcePositiveType allowedParStrength = allowedParStrength or Strength @@ -137,11 +144,14 @@ class UniversalGaussianDistribution(SB3_Distribution): def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6): super(UniversalGaussianDistribution, self).__init__() self.action_dim = action_dim - self.par_strength = neural_strength - self.cov_strength = cov_strength - self.par_type = parameterization_type - self.enforce_positive_type = enforce_positive_type - self.prob_squashing_type = prob_squashing_type + self.par_strength = cast_to_enum(neural_strength, Strength) + self.cov_strength = cast_to_enum(cov_strength, Strength) + self.par_type = cast_to_enum( + parameterization_type, ParametrizationType) + self.enforce_positive_type = cast_to_enum( + enforce_positive_type, EnforcePositiveType) + self.prob_squashing_type = cast_to_enum( + prob_squashing_type, EnforcePositiveType) self.epsilon = epsilon