Implement support for Dict spaces for time_aware_observation-wrapper
This commit is contained in:
		
							parent
							
								
									abeb963b4e
								
							
						
					
					
						commit
						a23b44752e
					
				| @ -2,6 +2,7 @@ from gymnasium.spaces import Box, Dict | |||||||
| from gym.spaces import Box as OldBox | from gym.spaces import Box as OldBox | ||||||
| import gymnasium as gym | import gymnasium as gym | ||||||
| import numpy as np | import numpy as np | ||||||
|  | import copy | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): | class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): | ||||||
| @ -24,19 +25,21 @@ class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr | |||||||
|         if enforce_dtype_float32: |         if enforce_dtype_float32: | ||||||
|             assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str( |             assert env.observation_space.dtype == np.float32, 'TimeAwareObservation was given an environment with a dtype!=np.float32 ('+str( | ||||||
|                 env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.' |                 env.observation_space.dtype)+'). This requirement can be removed by setting enforce_dtype_float32=False.' | ||||||
|         dtype = env.observation_space.dtype |  | ||||||
| 
 |  | ||||||
|         assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict' |         assert env.observation_space.__class__ in allowed_classes, str(env.observation_space)+' is not supported. Only Box or Dict' | ||||||
| 
 | 
 | ||||||
|         if env.observation_space.__class__ in [Box, OldBox]: |         if env.observation_space.__class__ in [Box, OldBox]: | ||||||
|  |             dtype = env.observation_space.dtype | ||||||
|  | 
 | ||||||
|             low = np.append(env.observation_space.low, 0.0) |             low = np.append(env.observation_space.low, 0.0) | ||||||
|             high = np.append(env.observation_space.high, 1.0) |             high = np.append(env.observation_space.high, 1.0) | ||||||
| 
 | 
 | ||||||
|             self.observation_space = Box(low, high, dtype=dtype) |             self.observation_space = Box(low, high, dtype=dtype) | ||||||
|         else: |         else: | ||||||
|             import pdb |             spaces = copy.copy(env.observation_space.spaces) | ||||||
|             pdb.set_trace() |             dtype = np.float64 | ||||||
|             exit |             spaces['time_awareness'] = Box(0, 1, dtype=dtype) | ||||||
|  | 
 | ||||||
|  |             self.observation_space = Dict(spaces) | ||||||
| 
 | 
 | ||||||
|         self.is_vector_env = getattr(env, "is_vector_env", False) |         self.is_vector_env = getattr(env, "is_vector_env", False) | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user