From f9e7a34edae81833024727ac3224c764218698ac Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 1 Dec 2021 15:55:38 +0100 Subject: [PATCH] added dtype in observation spaces --- alr_envs/alr/classic_control/hole_reacher/hole_reacher.py | 2 +- .../alr/classic_control/simple_reacher/simple_reacher.py | 2 +- .../classic_control/viapoint_reacher/viapoint_reacher.py | 2 +- alr_envs/dmc/dmc_wrapper.py | 8 ++++---- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py index dd7321a..208f005 100644 --- a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py +++ b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py @@ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv): # self._tmp_hole_depth, self.end_effector - self._goal, self._steps - ]) + ]).astype(np.float32) def _get_line_points(self, num_points_per_link=1): theta = self._joint_angles[:, None] diff --git a/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py b/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py index 758f824..dac06a3 100644 --- a/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py +++ b/alr_envs/alr/classic_control/simple_reacher/simple_reacher.py @@ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): self._angle_velocity, self.end_effector - self._goal, self._steps - ]) + ]).astype(np.float32) def _generate_goal(self): diff --git a/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py b/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py index 748eb99..b44647e 100644 --- a/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/alr_envs/alr/classic_control/viapoint_reacher/viapoint_reacher.py @@ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv): self.end_effector - self._via_point, self.end_effector - self._goal, self._steps - ]) + ]).astype(np.float32) def _check_collisions(self) -> bool: return self._check_self_collision() diff --git a/alr_envs/dmc/dmc_wrapper.py b/alr_envs/dmc/dmc_wrapper.py index 10f1af9..aa6c7aa 100644 --- a/alr_envs/dmc/dmc_wrapper.py +++ b/alr_envs/dmc/dmc_wrapper.py @@ -15,10 +15,10 @@ def _spec_to_box(spec): assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found" dim = int(np.prod(s.shape)) if type(s) == specs.Array: - bound = np.inf * np.ones(dim, dtype=np.float32) + bound = np.inf * np.ones(dim, dtype=s.dtype) return -bound, bound elif type(s) == specs.BoundedArray: - zeros = np.zeros(dim, dtype=np.float32) + zeros = np.zeros(dim, dtype=s.dtype) return s.minimum + zeros, s.maximum + zeros mins, maxs = [], [] @@ -29,7 +29,7 @@ def _spec_to_box(spec): low = np.concatenate(mins, axis=0) high = np.concatenate(maxs, axis=0) assert low.shape == high.shape - return spaces.Box(low, high, dtype=np.float32) + return spaces.Box(low, high, dtype=s.dtype) def _flatten_obs(obs: collections.MutableMapping): @@ -113,7 +113,7 @@ class DMCWrapper(core.Env): if self._channels_first: obs = obs.transpose(2, 0, 1).copy() else: - obs = _flatten_obs(time_step.observation) + obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype) return obs @property