Fixed smol bugs when instanciating based on string-names of enums

This commit is contained in:
Dominik Moritz Roth 2022-08-06 14:36:35 +02:00
parent e074294b88
commit 2c1689fbbc

View File

@ -116,10 +116,12 @@ def make_proba_distribution(
if dist_kwargs is None: if dist_kwargs is None:
dist_kwargs = {} dist_kwargs = {}
dist_kwargs['use_sde'] = use_sde
if isinstance(action_space, gym.spaces.Box): if isinstance(action_space, gym.spaces.Box):
assert len( assert len(
action_space.shape) == 1, "Error: the action space must be a vector" action_space.shape) == 1, "Error: the action space must be a vector"
return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs) return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, gym.spaces.Discrete): elif isinstance(action_space, gym.spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs) return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, gym.spaces.MultiDiscrete): elif isinstance(action_space, gym.spaces.MultiDiscrete):
@ -151,7 +153,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
self.enforce_positive_type = cast_to_enum( self.enforce_positive_type = cast_to_enum(
enforce_positive_type, EnforcePositiveType) enforce_positive_type, EnforcePositiveType)
self.prob_squashing_type = cast_to_enum( self.prob_squashing_type = cast_to_enum(
prob_squashing_type, EnforcePositiveType) prob_squashing_type, ProbSquashingType)
self.epsilon = epsilon self.epsilon = epsilon
@ -161,8 +163,8 @@ class UniversalGaussianDistribution(SB3_Distribution):
if use_sde: if use_sde:
raise Exception('SDE is not yet implemented') raise Exception('SDE is not yet implemented')
assert (parameterization_type != ParametrizationType.NONE) == ( assert (self.par_type != ParametrizationType.NONE) == (
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE: if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE:
raise Exception( raise Exception(