Automatic casting to enums
This commit is contained in:
parent
683644f77d
commit
8b82347056
@ -72,6 +72,13 @@ class ProbSquashingType(Enum):
|
|||||||
return [nn.Identity(), TanhBijector.inverse][self.value](x)
|
return [nn.Identity(), TanhBijector.inverse][self.value](x)
|
||||||
|
|
||||||
|
|
||||||
|
def cast_to_enum(inp, Class):
|
||||||
|
if isinstance(inp, Enum):
|
||||||
|
return inp
|
||||||
|
else:
|
||||||
|
return Class[inp]
|
||||||
|
|
||||||
|
|
||||||
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None):
|
||||||
allowedEPTs = allowedEPTs or EnforcePositiveType
|
allowedEPTs = allowedEPTs or EnforcePositiveType
|
||||||
allowedParStrength = allowedParStrength or Strength
|
allowedParStrength = allowedParStrength or Strength
|
||||||
@ -137,11 +144,14 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
|
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE, epsilon=1e-6):
|
||||||
super(UniversalGaussianDistribution, self).__init__()
|
super(UniversalGaussianDistribution, self).__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.par_strength = neural_strength
|
self.par_strength = cast_to_enum(neural_strength, Strength)
|
||||||
self.cov_strength = cov_strength
|
self.cov_strength = cast_to_enum(cov_strength, Strength)
|
||||||
self.par_type = parameterization_type
|
self.par_type = cast_to_enum(
|
||||||
self.enforce_positive_type = enforce_positive_type
|
parameterization_type, ParametrizationType)
|
||||||
self.prob_squashing_type = prob_squashing_type
|
self.enforce_positive_type = cast_to_enum(
|
||||||
|
enforce_positive_type, EnforcePositiveType)
|
||||||
|
self.prob_squashing_type = cast_to_enum(
|
||||||
|
prob_squashing_type, EnforcePositiveType)
|
||||||
|
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user