Fixed smol bugs when instanciating based on string-names of enums
This commit is contained in:
		
							parent
							
								
									e074294b88
								
							
						
					
					
						commit
						2c1689fbbc
					
				| @ -116,10 +116,12 @@ def make_proba_distribution( | ||||
|     if dist_kwargs is None: | ||||
|         dist_kwargs = {} | ||||
| 
 | ||||
|     dist_kwargs['use_sde'] = use_sde | ||||
| 
 | ||||
|     if isinstance(action_space, gym.spaces.Box): | ||||
|         assert len( | ||||
|             action_space.shape) == 1, "Error: the action space must be a vector" | ||||
|         return UniversalGaussianDistribution(get_action_dim(action_space), use_sde=use_sde, **dist_kwargs) | ||||
|         return UniversalGaussianDistribution(get_action_dim(action_space), **dist_kwargs) | ||||
|     elif isinstance(action_space, gym.spaces.Discrete): | ||||
|         return CategoricalDistribution(action_space.n, **dist_kwargs) | ||||
|     elif isinstance(action_space, gym.spaces.MultiDiscrete): | ||||
| @ -151,7 +153,7 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|         self.enforce_positive_type = cast_to_enum( | ||||
|             enforce_positive_type, EnforcePositiveType) | ||||
|         self.prob_squashing_type = cast_to_enum( | ||||
|             prob_squashing_type, EnforcePositiveType) | ||||
|             prob_squashing_type, ProbSquashingType) | ||||
| 
 | ||||
|         self.epsilon = epsilon | ||||
| 
 | ||||
| @ -161,8 +163,8 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|         if use_sde: | ||||
|             raise Exception('SDE is not yet implemented') | ||||
| 
 | ||||
|         assert (parameterization_type != ParametrizationType.NONE) == ( | ||||
|             cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' | ||||
|         assert (self.par_type != ParametrizationType.NONE) == ( | ||||
|             self.cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' | ||||
| 
 | ||||
|         if self.par_type == ParametrizationType.SPHERICAL_CHOL and self.enforce_positive_type == EnforcePositiveType.NONE: | ||||
|             raise Exception( | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user