added dtype in observation spaces

This commit is contained in:
Maximilian Huettenrauch 2021-12-01 15:55:38 +01:00
parent ca90f257d4
commit f9e7a34eda
4 changed files with 7 additions and 7 deletions

View File

@ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
# self._tmp_hole_depth, # self._tmp_hole_depth,
self.end_effector - self._goal, self.end_effector - self._goal,
self._steps self._steps
]) ]).astype(np.float32)
def _get_line_points(self, num_points_per_link=1): def _get_line_points(self, num_points_per_link=1):
theta = self._joint_angles[:, None] theta = self._joint_angles[:, None]

View File

@ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
self._angle_velocity, self._angle_velocity,
self.end_effector - self._goal, self.end_effector - self._goal,
self._steps self._steps
]) ]).astype(np.float32)
def _generate_goal(self): def _generate_goal(self):

View File

@ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
self.end_effector - self._via_point, self.end_effector - self._via_point,
self.end_effector - self._goal, self.end_effector - self._goal,
self._steps self._steps
]) ]).astype(np.float32)
def _check_collisions(self) -> bool: def _check_collisions(self) -> bool:
return self._check_self_collision() return self._check_self_collision()

View File

@ -15,10 +15,10 @@ def _spec_to_box(spec):
assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found" assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found"
dim = int(np.prod(s.shape)) dim = int(np.prod(s.shape))
if type(s) == specs.Array: if type(s) == specs.Array:
bound = np.inf * np.ones(dim, dtype=np.float32) bound = np.inf * np.ones(dim, dtype=s.dtype)
return -bound, bound return -bound, bound
elif type(s) == specs.BoundedArray: elif type(s) == specs.BoundedArray:
zeros = np.zeros(dim, dtype=np.float32) zeros = np.zeros(dim, dtype=s.dtype)
return s.minimum + zeros, s.maximum + zeros return s.minimum + zeros, s.maximum + zeros
mins, maxs = [], [] mins, maxs = [], []
@ -29,7 +29,7 @@ def _spec_to_box(spec):
low = np.concatenate(mins, axis=0) low = np.concatenate(mins, axis=0)
high = np.concatenate(maxs, axis=0) high = np.concatenate(maxs, axis=0)
assert low.shape == high.shape assert low.shape == high.shape
return spaces.Box(low, high, dtype=np.float32) return spaces.Box(low, high, dtype=s.dtype)
def _flatten_obs(obs: collections.MutableMapping): def _flatten_obs(obs: collections.MutableMapping):
@ -113,7 +113,7 @@ class DMCWrapper(core.Env):
if self._channels_first: if self._channels_first:
obs = obs.transpose(2, 0, 1).copy() obs = obs.transpose(2, 0, 1).copy()
else: else:
obs = _flatten_obs(time_step.observation) obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype)
return obs return obs
@property @property