added dtype in observation spaces
This commit is contained in:
		
							parent
							
								
									ca90f257d4
								
							
						
					
					
						commit
						f9e7a34eda
					
				| @ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): | ||||
|             # self._tmp_hole_depth, | ||||
|             self.end_effector - self._goal, | ||||
|             self._steps | ||||
|         ]) | ||||
|         ]).astype(np.float32) | ||||
| 
 | ||||
|     def _get_line_points(self, num_points_per_link=1): | ||||
|         theta = self._joint_angles[:, None] | ||||
|  | ||||
| @ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): | ||||
|             self._angle_velocity, | ||||
|             self.end_effector - self._goal, | ||||
|             self._steps | ||||
|         ]) | ||||
|         ]).astype(np.float32) | ||||
| 
 | ||||
|     def _generate_goal(self): | ||||
| 
 | ||||
|  | ||||
| @ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): | ||||
|             self.end_effector - self._via_point, | ||||
|             self.end_effector - self._goal, | ||||
|             self._steps | ||||
|         ]) | ||||
|         ]).astype(np.float32) | ||||
| 
 | ||||
|     def _check_collisions(self) -> bool: | ||||
|         return self._check_self_collision() | ||||
|  | ||||
| @ -15,10 +15,10 @@ def _spec_to_box(spec): | ||||
|         assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found" | ||||
|         dim = int(np.prod(s.shape)) | ||||
|         if type(s) == specs.Array: | ||||
|             bound = np.inf * np.ones(dim, dtype=np.float32) | ||||
|             bound = np.inf * np.ones(dim, dtype=s.dtype) | ||||
|             return -bound, bound | ||||
|         elif type(s) == specs.BoundedArray: | ||||
|             zeros = np.zeros(dim, dtype=np.float32) | ||||
|             zeros = np.zeros(dim, dtype=s.dtype) | ||||
|             return s.minimum + zeros, s.maximum + zeros | ||||
| 
 | ||||
|     mins, maxs = [], [] | ||||
| @ -29,7 +29,7 @@ def _spec_to_box(spec): | ||||
|     low = np.concatenate(mins, axis=0) | ||||
|     high = np.concatenate(maxs, axis=0) | ||||
|     assert low.shape == high.shape | ||||
|     return spaces.Box(low, high, dtype=np.float32) | ||||
|     return spaces.Box(low, high, dtype=s.dtype) | ||||
| 
 | ||||
| 
 | ||||
| def _flatten_obs(obs: collections.MutableMapping): | ||||
| @ -113,7 +113,7 @@ class DMCWrapper(core.Env): | ||||
|             if self._channels_first: | ||||
|                 obs = obs.transpose(2, 0, 1).copy() | ||||
|         else: | ||||
|             obs = _flatten_obs(time_step.observation) | ||||
|             obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype) | ||||
|         return obs | ||||
| 
 | ||||
|     @property | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user