Automatic casting to enums
This commit is contained in:
parent
683644f77d
commit
8b82347056
@ -72,6 +72,13 @@ class ProbSquashingType(Enum):
|
||||
return [nn.Identity(), TanhBijector.inverse][self.value](x)
|
||||
|
||||
|
||||
def cast_to_enum(inp, Class):
|
||||
if isinstance(inp, Enum):
|
||||
return inp
|
||||
else:
|
||||
return Class[inp]
|
||||
|
||||
|
||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||
allowedParStrength = allowedParStrength or Strength
|
||||
@ -137,11 +144,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
|
||||
super(UniversalGaussianDistribution, self).__init__()
|
||||
self.action_dim = action_dim
|
||||
self.par_strength = neural_strength
|
||||
self.cov_strength = cov_strength
|
||||
self.par_type = parameterization_type
|
||||
self.enforce_positive_type = enforce_positive_type
|
||||
self.prob_squashing_type = prob_squashing_type
|
||||
self.par_strength = cast_to_enum(neural_strength, Strength)
|
||||
self.cov_strength = cast_to_enum(cov_strength, Strength)
|
||||
self.par_type = cast_to_enum(
|
||||
parameterization_type, ParametrizationType)
|
||||
self.enforce_positive_type = cast_to_enum(
|
||||
enforce_positive_type, EnforcePositiveType)
|
||||
self.prob_squashing_type = cast_to_enum(
|
||||
prob_squashing_type, EnforcePositiveType)
|
||||
|
||||
self.epsilon = epsilon
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user