Implemented Visualization of Value-Function

This commit is contained in:
Dominik Moritz Roth 2022-10-15 20:54:16 +02:00
parent e09a588e77
commit 66c9509c27

View File

@ -94,6 +94,7 @@ class ColumbusEnv(gym.Env):
self.max_steps = max_steps
self._steps = 0
self._has_value_map = False
self.paused = False
self.keypress_timeout = 0
@ -128,6 +129,8 @@ class ColumbusEnv(gym.Env):
self.surface = pygame.Surface((self.width, self.height))
self.path_overlay = pygame.Surface(
(self.width, self.height), pygame.SRCALPHA, 32)
self.value_overlay = pygame.Surface(
(self.width, self.height), pygame.SRCALPHA, 32)
if self.visible:
self.screen = pygame.display.set_mode(
(self.width, self.height))
@ -267,6 +270,7 @@ class ColumbusEnv(gym.Env):
pygame.init()
self._init = True
self._steps = 0
self._has_value_map = False
self._seed(self.env_seed)
self._rendered = False
self._disturb_next = False
@ -294,6 +298,42 @@ class ColumbusEnv(gym.Env):
for entity in self.entities:
entity.draw()
def _draw_values(self, value_func, static=True, resolution=64, color_depth=255):
if not (static and self._has_value_map):
agentpos = self.agent.pos
agentspeed = self.agent.speed
self.agent.speed = (0, 0)
self.value_overlay = pygame.Surface(
(self.width, self.height), pygame.SRCALPHA, 32)
obs = []
for i in range(resolution):
for j in range(resolution):
x, y = i*(self.width/resolution), j * \
(self.height/resolution)
self.agent.pos = x, y
ob = self.observable.get_observation()
obs.append(ob)
self.agent.pos = agentpos
self.agent.speed = agentspeed
V = value_func(th.Tensor(obs))
V -= V.min()
V /= V.max()
c = 0
for i in range(resolution):
for j in range(resolution):
v = V[c]
c += 1
col = [int((1-c)*color_depth), int(c*color_depth), 0]
x, y = i*(self.width/resolution), j * \
(self.height/resolution)
rect = pygame.Rect(x, y, self.env.width/resolution,
self.height/resolution)
pygame.draw.rect(self.value_overlay, col,
rect, width=0)
self.screen.blit(self.value_overlay, (0, 0))
def _draw_observable(self, forceDraw=False):
if self.draw_observable and (self.visible or forceDraw):
self.observable.draw()
@ -386,7 +426,7 @@ class ColumbusEnv(gym.Env):
elif keys[pygame.K_d]:
self._disturb_next = (1.0, 0.5)
def render(self, mode='human', dont_show=False, chol=None):
def render(self, mode='human', dont_show=False, chol=None, value_func=None, values_static=True):
if mode == 'human':
self._handle_user_input()
self.visible = self.visible or not dont_show
@ -400,8 +440,10 @@ class ColumbusEnv(gym.Env):
self._rendered = True
if mode == 'human' and dont_show:
return
self.screen.blit(self.surface, (0, 0))
if value_func != None:
self._draw_values(value_func, values_static)
self.screen.blit(self.path_overlay, (0, 0))
self.screen.blit(self.surface, (0, 0))
self._draw_observable(forceDraw=mode != 'human')
self._draw_joystick(forceDraw=mode != 'human')
if chol != None: