diff --git a/columbus/observables.py b/columbus/observables.py index d502ae5..be97ed2 100644 --- a/columbus/observables.py +++ b/columbus/observables.py @@ -97,7 +97,7 @@ class RayObservable(Observable): self.include_rand = include_rand def get_observation_space(self): - return spaces.Box(low=0, high=self.num_steps, + return spaces.Box(low=0, high=1, shape=(self.num_rays+self.include_rand, self.num_chans), dtype=np.uint8) def _get_ray_heads(self): @@ -148,7 +148,7 @@ class RayObservable(Observable): self.env.agent.pos[0]*self.env.width, sy + \ self.env.agent.pos[1]*self.env.height if self._check_collision((rx, ry), entity_type, entities): - self.rays[r, c] = self.num_steps-s + self.rays[r, c] = (self.num_steps-s)/self.num_steps if self.occlusion: occ_dist = s break @@ -157,7 +157,7 @@ class RayObservable(Observable): def draw(self): for c, entity_type in enumerate(self.chans): for r, (hx, hy) in enumerate(self._get_ray_heads()): - s = self.num_steps - self.rays[r, c] + s = self.num_steps - self.rays[r, c]*self.num_steps sx, sy = (s+1)*hx/self.num_steps, (s+1)*hy/self.num_steps rx, ry = sx + \ self.env.agent.pos[0]*self.env.width, sy + \