diff --git a/columbus/env.py b/columbus/env.py index e4fc218..2a059f2 100644 --- a/columbus/env.py +++ b/columbus/env.py @@ -94,6 +94,7 @@ class ColumbusEnv(gym.Env): self.max_steps = max_steps self._steps = 0 + self._has_value_map = False self.paused = False self.keypress_timeout = 0 @@ -128,6 +129,8 @@ class ColumbusEnv(gym.Env): self.surface = pygame.Surface((self.width, self.height)) self.path_overlay = pygame.Surface( (self.width, self.height), pygame.SRCALPHA, 32) + self.value_overlay = pygame.Surface( + (self.width, self.height), pygame.SRCALPHA, 32) if self.visible: self.screen = pygame.display.set_mode( (self.width, self.height)) @@ -267,6 +270,7 @@ class ColumbusEnv(gym.Env): pygame.init() self._init = True self._steps = 0 + self._has_value_map = False self._seed(self.env_seed) self._rendered = False self._disturb_next = False @@ -294,6 +298,42 @@ class ColumbusEnv(gym.Env): for entity in self.entities: entity.draw() + def _draw_values(self, value_func, static=True, resolution=64, color_depth=255): + if not (static and self._has_value_map): + agentpos = self.agent.pos + agentspeed = self.agent.speed + self.agent.speed = (0, 0) + self.value_overlay = pygame.Surface( + (self.width, self.height), pygame.SRCALPHA, 32) + obs = [] + for i in range(resolution): + for j in range(resolution): + x, y = i*(self.width/resolution), j * \ + (self.height/resolution) + self.agent.pos = x, y + ob = self.observable.get_observation() + obs.append(ob) + self.agent.pos = agentpos + self.agent.speed = agentspeed + + V = value_func(th.Tensor(obs)) + V -= V.min() + V /= V.max() + + c = 0 + for i in range(resolution): + for j in range(resolution): + v = V[c] + c += 1 + col = [int((1-c)*color_depth), int(c*color_depth), 0] + x, y = i*(self.width/resolution), j * \ + (self.height/resolution) + rect = pygame.Rect(x, y, self.env.width/resolution, + self.height/resolution) + pygame.draw.rect(self.value_overlay, col, + rect, width=0) + self.screen.blit(self.value_overlay, (0, 0)) + def _draw_observable(self, forceDraw=False): if self.draw_observable and (self.visible or forceDraw): self.observable.draw() @@ -386,7 +426,7 @@ class ColumbusEnv(gym.Env): elif keys[pygame.K_d]: self._disturb_next = (1.0, 0.5) - def render(self, mode='human', dont_show=False, chol=None): + def render(self, mode='human', dont_show=False, chol=None, value_func=None, values_static=True): if mode == 'human': self._handle_user_input() self.visible = self.visible or not dont_show @@ -400,8 +440,10 @@ class ColumbusEnv(gym.Env): self._rendered = True if mode == 'human' and dont_show: return - self.screen.blit(self.surface, (0, 0)) + if value_func != None: + self._draw_values(value_func, values_static) self.screen.blit(self.path_overlay, (0, 0)) + self.screen.blit(self.surface, (0, 0)) self._draw_observable(forceDraw=mode != 'human') self._draw_joystick(forceDraw=mode != 'human') if chol != None: