added dtype in observation spaces

This commit is contained in:
Maximilian Huettenrauch 2021-12-01 15:55:38 +01:00
parent ca90f257d4
commit f9e7a34eda
4 changed files with 7 additions and 7 deletions

View File

@ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
# self._tmp_hole_depth,
self.end_effector - self._goal,
self._steps
])
]).astype(np.float32)
def _get_line_points(self, num_points_per_link=1):
theta = self._joint_angles[:, None]

View File

@ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
self._angle_velocity,
self.end_effector - self._goal,
self._steps
])
]).astype(np.float32)
def _generate_goal(self):

View File

@ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
self.end_effector - self._via_point,
self.end_effector - self._goal,
self._steps
])
]).astype(np.float32)
def _check_collisions(self) -> bool:
return self._check_self_collision()

View File

@ -15,10 +15,10 @@ def _spec_to_box(spec):
assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found"
dim = int(np.prod(s.shape))
if type(s) == specs.Array:
bound = np.inf * np.ones(dim, dtype=np.float32)
bound = np.inf * np.ones(dim, dtype=s.dtype)
return -bound, bound
elif type(s) == specs.BoundedArray:
zeros = np.zeros(dim, dtype=np.float32)
zeros = np.zeros(dim, dtype=s.dtype)
return s.minimum + zeros, s.maximum + zeros
mins, maxs = [], []
@ -29,7 +29,7 @@ def _spec_to_box(spec):
low = np.concatenate(mins, axis=0)
high = np.concatenate(maxs, axis=0)
assert low.shape == high.shape
return spaces.Box(low, high, dtype=np.float32)
return spaces.Box(low, high, dtype=s.dtype)
def _flatten_obs(obs: collections.MutableMapping):
@ -113,7 +113,7 @@ class DMCWrapper(core.Env):
if self._channels_first:
obs = obs.transpose(2, 0, 1).copy()
else:
obs = _flatten_obs(time_step.observation)
obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype)
return obs
@property