added dtype in observation spaces
This commit is contained in:
parent
ca90f257d4
commit
f9e7a34eda
@ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
# self._tmp_hole_depth,
|
# self._tmp_hole_depth,
|
||||||
self.end_effector - self._goal,
|
self.end_effector - self._goal,
|
||||||
self._steps
|
self._steps
|
||||||
])
|
]).astype(np.float32)
|
||||||
|
|
||||||
def _get_line_points(self, num_points_per_link=1):
|
def _get_line_points(self, num_points_per_link=1):
|
||||||
theta = self._joint_angles[:, None]
|
theta = self._joint_angles[:, None]
|
||||||
|
@ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
self._angle_velocity,
|
self._angle_velocity,
|
||||||
self.end_effector - self._goal,
|
self.end_effector - self._goal,
|
||||||
self._steps
|
self._steps
|
||||||
])
|
]).astype(np.float32)
|
||||||
|
|
||||||
def _generate_goal(self):
|
def _generate_goal(self):
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.end_effector - self._via_point,
|
self.end_effector - self._via_point,
|
||||||
self.end_effector - self._goal,
|
self.end_effector - self._goal,
|
||||||
self._steps
|
self._steps
|
||||||
])
|
]).astype(np.float32)
|
||||||
|
|
||||||
def _check_collisions(self) -> bool:
|
def _check_collisions(self) -> bool:
|
||||||
return self._check_self_collision()
|
return self._check_self_collision()
|
||||||
|
@ -15,10 +15,10 @@ def _spec_to_box(spec):
|
|||||||
assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found"
|
assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found"
|
||||||
dim = int(np.prod(s.shape))
|
dim = int(np.prod(s.shape))
|
||||||
if type(s) == specs.Array:
|
if type(s) == specs.Array:
|
||||||
bound = np.inf * np.ones(dim, dtype=np.float32)
|
bound = np.inf * np.ones(dim, dtype=s.dtype)
|
||||||
return -bound, bound
|
return -bound, bound
|
||||||
elif type(s) == specs.BoundedArray:
|
elif type(s) == specs.BoundedArray:
|
||||||
zeros = np.zeros(dim, dtype=np.float32)
|
zeros = np.zeros(dim, dtype=s.dtype)
|
||||||
return s.minimum + zeros, s.maximum + zeros
|
return s.minimum + zeros, s.maximum + zeros
|
||||||
|
|
||||||
mins, maxs = [], []
|
mins, maxs = [], []
|
||||||
@ -29,7 +29,7 @@ def _spec_to_box(spec):
|
|||||||
low = np.concatenate(mins, axis=0)
|
low = np.concatenate(mins, axis=0)
|
||||||
high = np.concatenate(maxs, axis=0)
|
high = np.concatenate(maxs, axis=0)
|
||||||
assert low.shape == high.shape
|
assert low.shape == high.shape
|
||||||
return spaces.Box(low, high, dtype=np.float32)
|
return spaces.Box(low, high, dtype=s.dtype)
|
||||||
|
|
||||||
|
|
||||||
def _flatten_obs(obs: collections.MutableMapping):
|
def _flatten_obs(obs: collections.MutableMapping):
|
||||||
@ -113,7 +113,7 @@ class DMCWrapper(core.Env):
|
|||||||
if self._channels_first:
|
if self._channels_first:
|
||||||
obs = obs.transpose(2, 0, 1).copy()
|
obs = obs.transpose(2, 0, 1).copy()
|
||||||
else:
|
else:
|
||||||
obs = _flatten_obs(time_step.observation)
|
obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype)
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user