Implemented Visualization of Value-Function
This commit is contained in:
parent
e09a588e77
commit
66c9509c27
@ -94,6 +94,7 @@ class ColumbusEnv(gym.Env):
|
||||
|
||||
self.max_steps = max_steps
|
||||
self._steps = 0
|
||||
self._has_value_map = False
|
||||
|
||||
self.paused = False
|
||||
self.keypress_timeout = 0
|
||||
@ -128,6 +129,8 @@ class ColumbusEnv(gym.Env):
|
||||
self.surface = pygame.Surface((self.width, self.height))
|
||||
self.path_overlay = pygame.Surface(
|
||||
(self.width, self.height), pygame.SRCALPHA, 32)
|
||||
self.value_overlay = pygame.Surface(
|
||||
(self.width, self.height), pygame.SRCALPHA, 32)
|
||||
if self.visible:
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.width, self.height))
|
||||
@ -267,6 +270,7 @@ class ColumbusEnv(gym.Env):
|
||||
pygame.init()
|
||||
self._init = True
|
||||
self._steps = 0
|
||||
self._has_value_map = False
|
||||
self._seed(self.env_seed)
|
||||
self._rendered = False
|
||||
self._disturb_next = False
|
||||
@ -294,6 +298,42 @@ class ColumbusEnv(gym.Env):
|
||||
for entity in self.entities:
|
||||
entity.draw()
|
||||
|
||||
def _draw_values(self, value_func, static=True, resolution=64, color_depth=255):
    """Draw a heat-map of ``value_func`` over the whole field onto the screen.

    The agent is temporarily placed, with zero speed, on every cell origin
    of a ``resolution`` x ``resolution`` grid; the observation at each
    position is collected and all observations are passed to ``value_func``
    in a single call. The returned values are min/max-normalised to [0, 1]
    and painted onto ``self.value_overlay`` as a red (low) to green (high)
    gradient, which is then blitted onto ``self.screen``.

    Args:
        value_func: Callable receiving ``th.Tensor(observations)`` and
            returning one value per observation (anything offering
            ``min``/``max``/``reshape``/``tolist``, e.g. a torch tensor).
        static: If true, the overlay is computed only once and re-used on
            later calls until ``self._has_value_map`` is reset to False.
        resolution: Number of grid cells along each axis.
        color_depth: Channel intensity used for a fully saturated colour.
    """
    if not (static and self._has_value_map):
        cell_w = self.width / resolution
        cell_h = self.height / resolution
        # NOTE(review): positions are written to agent.pos in the same
        # units the original code used (fractions of self.width/height,
        # i.e. pixels) -- confirm this matches what agent.pos expects.
        cells = [(i * cell_w, j * cell_h)
                 for i in range(resolution)
                 for j in range(resolution)]

        saved_pos = self.agent.pos
        saved_speed = self.agent.speed
        self.agent.speed = (0, 0)
        try:
            obs = []
            for x, y in cells:
                self.agent.pos = x, y
                obs.append(self.observable.get_observation())
        finally:
            # Always put the agent back, even if an observation fails.
            self.agent.pos = saved_pos
            self.agent.speed = saved_speed

        values = value_func(th.Tensor(obs))
        # Out-of-place arithmetic: do not mutate the caller's result.
        values = values - values.min()
        peak = values.max()
        if peak > 0:
            # A constant value function has peak == 0; skip the division
            # so the map is uniformly 0 instead of NaN.
            values = values / peak
        values = values.reshape(-1).tolist()

        self.value_overlay = pygame.Surface(
            (self.width, self.height), pygame.SRCALPHA, 32)
        for (x, y), v in zip(cells, values):
            # Colour by the normalised value (not by the cell counter).
            col = [int((1 - v) * color_depth), int(v * color_depth), 0]
            # +1 pixel so float truncation leaves no gaps between cells.
            rect = pygame.Rect(x, y, int(cell_w) + 1, int(cell_h) + 1)
            pygame.draw.rect(self.value_overlay, col, rect, width=0)
        # Mark the overlay as valid so static mode can re-use it.
        self._has_value_map = True
    self.screen.blit(self.value_overlay, (0, 0))
|
||||
|
||||
def _draw_observable(self, forceDraw=False):
|
||||
if self.draw_observable and (self.visible or forceDraw):
|
||||
self.observable.draw()
|
||||
@ -386,7 +426,7 @@ class ColumbusEnv(gym.Env):
|
||||
elif keys[pygame.K_d]:
|
||||
self._disturb_next = (1.0, 0.5)
|
||||
|
||||
def render(self, mode='human', dont_show=False, chol=None):
|
||||
def render(self, mode='human', dont_show=False, chol=None, value_func=None, values_static=True):
|
||||
if mode == 'human':
|
||||
self._handle_user_input()
|
||||
self.visible = self.visible or not dont_show
|
||||
@ -400,8 +440,10 @@ class ColumbusEnv(gym.Env):
|
||||
self._rendered = True
|
||||
if mode == 'human' and dont_show:
|
||||
return
|
||||
self.screen.blit(self.surface, (0, 0))
|
||||
if value_func != None:
|
||||
self._draw_values(value_func, values_static)
|
||||
self.screen.blit(self.path_overlay, (0, 0))
|
||||
self.screen.blit(self.surface, (0, 0))
|
||||
self._draw_observable(forceDraw=mode != 'human')
|
||||
self._draw_joystick(forceDraw=mode != 'human')
|
||||
if chol != None:
|
||||
|
Loading…
Reference in New Issue
Block a user