Fixed smol bugs when instanciating based on string-names of enums
This commit is contained in:
parent
e074294b88
commit
2c1689fbbc
@ -116,10 +116,12 @@ def make_proba_distribution(
|
||||
if dist_kwargs is None:
|
||||
dist_kwargs = {}
|
||||
|
||||
dist_kwargs['use_sde'] = use_sde
|
||||
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
assert len(
|
||||
action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs)
|
||||
return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||
@ -151,7 +153,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
self.enforce_positive_type = cast_to_enum(
|
||||
enforce_positive_type, EnforcePositiveType)
|
||||
self.prob_squashing_type = cast_to_enum(
|
||||
prob_squashing_type, EnforcePositiveType)
|
||||
prob_squashing_type, ProbSquashingType)
|
||||
|
||||
self.epsilon = epsilon
|
||||
|
||||
@ -161,8 +163,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
if use_sde:
|
||||
raise Exception('SDE is not yet implemented')
|
||||
|
||||
assert (parameterization_type != ParametrizationType.NONE) == (
|
||||
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||
assert (self.par_type != ParametrizationType.NONE) == (
|
||||
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||
|
||||
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
|
||||
raise Exception(
|
||||
|
Loading…
Reference in New Issue
Block a user