Fixed smol bugs when instanciating based on string-names of enums
This commit is contained in:
parent
e074294b88
commit
2c1689fbbc
@ -116,10 +116,12 @@ def make_proba_distribution(
|
|||||||
if dist_kwargs is None:
|
if dist_kwargs is None:
|
||||||
dist_kwargs = {}
|
dist_kwargs = {}
|
||||||
|
|
||||||
|
dist_kwargs['use_sde'] = use_sde
|
||||||
|
|
||||||
if isinstance(action_space, gym.spaces.Box):
|
if isinstance(action_space, gym.spaces.Box):
|
||||||
assert len(
|
assert len(
|
||||||
action_space.shape) == 1, "Error: the action space must be a vector"
|
action_space.shape) == 1, "Error: the action space must be a vector"
|
||||||
return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs)
|
return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||||
elif isinstance(action_space, gym.spaces.Discrete):
|
elif isinstance(action_space, gym.spaces.Discrete):
|
||||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||||
@ -151,7 +153,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
self.enforce_positive_type = cast_to_enum(
|
self.enforce_positive_type = cast_to_enum(
|
||||||
enforce_positive_type, EnforcePositiveType)
|
enforce_positive_type, EnforcePositiveType)
|
||||||
self.prob_squashing_type = cast_to_enum(
|
self.prob_squashing_type = cast_to_enum(
|
||||||
prob_squashing_type, EnforcePositiveType)
|
prob_squashing_type, ProbSquashingType)
|
||||||
|
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
|
||||||
@ -161,8 +163,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
if use_sde:
|
if use_sde:
|
||||||
raise Exception('SDE is not yet implemented')
|
raise Exception('SDE is not yet implemented')
|
||||||
|
|
||||||
assert (parameterization_type != ParametrizationType.NONE) == (
|
assert (self.par_type != ParametrizationType.NONE) == (
|
||||||
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
|
|
||||||
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
|
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
Loading…
Reference in New Issue
Block a user