added dtype in observation spaces
This commit is contained in:
parent
ca90f257d4
commit
f9e7a34eda
@ -106,7 +106,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
# self._tmp_hole_depth,
|
||||
self.end_effector - self._goal,
|
||||
self._steps
|
||||
])
|
||||
]).astype(np.float32)
|
||||
|
||||
def _get_line_points(self, num_points_per_link=1):
|
||||
theta = self._joint_angles[:, None]
|
||||
|
@ -73,7 +73,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
||||
self._angle_velocity,
|
||||
self.end_effector - self._goal,
|
||||
self._steps
|
||||
])
|
||||
]).astype(np.float32)
|
||||
|
||||
def _generate_goal(self):
|
||||
|
||||
|
@ -111,7 +111,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
self.end_effector - self._via_point,
|
||||
self.end_effector - self._goal,
|
||||
self._steps
|
||||
])
|
||||
]).astype(np.float32)
|
||||
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision()
|
||||
|
@ -15,10 +15,10 @@ def _spec_to_box(spec):
|
||||
assert s.dtype == np.float64 or s.dtype == np.float32, f"Only float64 and float32 types are allowed, instead {s.dtype} was found"
|
||||
dim = int(np.prod(s.shape))
|
||||
if type(s) == specs.Array:
|
||||
bound = np.inf * np.ones(dim, dtype=np.float32)
|
||||
bound = np.inf * np.ones(dim, dtype=s.dtype)
|
||||
return -bound, bound
|
||||
elif type(s) == specs.BoundedArray:
|
||||
zeros = np.zeros(dim, dtype=np.float32)
|
||||
zeros = np.zeros(dim, dtype=s.dtype)
|
||||
return s.minimum + zeros, s.maximum + zeros
|
||||
|
||||
mins, maxs = [], []
|
||||
@ -29,7 +29,7 @@ def _spec_to_box(spec):
|
||||
low = np.concatenate(mins, axis=0)
|
||||
high = np.concatenate(maxs, axis=0)
|
||||
assert low.shape == high.shape
|
||||
return spaces.Box(low, high, dtype=np.float32)
|
||||
return spaces.Box(low, high, dtype=s.dtype)
|
||||
|
||||
|
||||
def _flatten_obs(obs: collections.MutableMapping):
|
||||
@ -113,7 +113,7 @@ class DMCWrapper(core.Env):
|
||||
if self._channels_first:
|
||||
obs = obs.transpose(2, 0, 1).copy()
|
||||
else:
|
||||
obs = _flatten_obs(time_step.observation)
|
||||
obs = _flatten_obs(time_step.observation).astype(self.observation_space.dtype)
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
Loading…
Reference in New Issue
Block a user