Automatic casting to enums

This commit is contained in:
Dominik Moritz Roth 2022-08-05 21:06:31 +02:00
parent 683644f77d
commit 8b82347056

View File

@ -72,6 +72,13 @@ class ProbSquashingType(Enum):
return [nn.Identity(), TanhBijector.inverse][self.value](x) return [nn.Identity(), TanhBijector.inverse][self.value](x)
def cast_to_enum(inp, Class):
if isinstance(inp, Enum):
return inp
else:
return Class[inp]
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
allowedEPTs = allowedEPTs or EnforcePositiveType allowedEPTs = allowedEPTs or EnforcePositiveType
allowedParStrength = allowedParStrength or Strength allowedParStrength = allowedParStrength or Strength
@ -137,11 +144,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6): def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
super(UniversalGaussianDistribution, self).__init__() super(UniversalGaussianDistribution, self).__init__()
self.action_dim = action_dim self.action_dim = action_dim
self.par_strength = neural_strength self.par_strength = cast_to_enum(neural_strength, Strength)
self.cov_strength = cov_strength self.cov_strength = cast_to_enum(cov_strength, Strength)
self.par_type = parameterization_type self.par_type = cast_to_enum(
self.enforce_positive_type = enforce_positive_type parameterization_type, ParametrizationType)
self.prob_squashing_type = prob_squashing_type self.enforce_positive_type = cast_to_enum(
enforce_positive_type, EnforcePositiveType)
self.prob_squashing_type = cast_to_enum(
prob_squashing_type, EnforcePositiveType)
self.epsilon = epsilon self.epsilon = epsilon