Fixed value_func rendering and path rendering, which were not working
This commit is contained in:
parent
70536b79e7
commit
e371d871e6
@ -298,7 +298,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
for entity in self.entities:
|
for entity in self.entities:
|
||||||
entity.draw()
|
entity.draw()
|
||||||
|
|
||||||
def _draw_values(self, value_func, static=True, resolution=64, color_depth=255):
|
def _draw_values(self, value_func, static=True, resolution=64, color_depth=192):
|
||||||
if not (static and self._has_value_map):
|
if not (static and self._has_value_map):
|
||||||
agentpos = self.agent.pos
|
agentpos = self.agent.pos
|
||||||
agentspeed = self.agent.speed
|
agentspeed = self.agent.speed
|
||||||
@ -308,7 +308,7 @@ class ColumbusEnv(gym.Env):
|
|||||||
obs = []
|
obs = []
|
||||||
for i in range(resolution):
|
for i in range(resolution):
|
||||||
for j in range(resolution):
|
for j in range(resolution):
|
||||||
x, y = i*(self.width/resolution), j * \
|
x, y = (i+0.5)*(self.width/resolution), (j+0.5) *\
|
||||||
(self.height/resolution)
|
(self.height/resolution)
|
||||||
self.agent.pos = x, y
|
self.agent.pos = x, y
|
||||||
ob = self.observable.get_observation()
|
ob = self.observable.get_observation()
|
||||||
@ -317,22 +317,24 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.agent.speed = agentspeed
|
self.agent.speed = agentspeed
|
||||||
|
|
||||||
V = value_func(th.Tensor(obs))
|
V = value_func(th.Tensor(obs))
|
||||||
V -= V.min()
|
V /= max(V.max(), -1*V.min())*2
|
||||||
V /= V.max()
|
V += 0.5
|
||||||
|
|
||||||
c = 0
|
c = 0
|
||||||
for i in range(resolution):
|
for i in range(resolution):
|
||||||
for j in range(resolution):
|
for j in range(resolution):
|
||||||
v = V[c]
|
v = V[c]
|
||||||
c += 1
|
c += 1
|
||||||
col = [int((1-c)*color_depth), int(c*color_depth), 0]
|
col = [int((1-v)*color_depth),
|
||||||
|
int(v*color_depth), 0, color_depth]
|
||||||
x, y = i*(self.width/resolution), j * \
|
x, y = i*(self.width/resolution), j * \
|
||||||
(self.height/resolution)
|
(self.height/resolution)
|
||||||
rect = pygame.Rect(x, y, self.env.width/resolution,
|
rect = pygame.Rect(x, y, int(self.width/resolution)+1,
|
||||||
self.height/resolution)
|
int(self.height/resolution)+1)
|
||||||
pygame.draw.rect(self.value_overlay, col,
|
pygame.draw.rect(self.value_overlay, col,
|
||||||
rect, width=0)
|
rect, width=0)
|
||||||
self.screen.blit(self.value_overlay, (0, 0))
|
self.surface.blit(self.value_overlay, (0, 0))
|
||||||
|
self._has_value_map = True
|
||||||
|
|
||||||
def _draw_observable(self, forceDraw=False):
|
def _draw_observable(self, forceDraw=False):
|
||||||
if self.draw_observable and (self.visible or forceDraw):
|
if self.draw_observable and (self.visible or forceDraw):
|
||||||
@ -433,6 +435,9 @@ class ColumbusEnv(gym.Env):
|
|||||||
self._ensure_surface()
|
self._ensure_surface()
|
||||||
pygame.draw.rect(self.surface, (0, 0, 0),
|
pygame.draw.rect(self.surface, (0, 0, 0),
|
||||||
pygame.Rect(0, 0, self.width, self.height))
|
pygame.Rect(0, 0, self.width, self.height))
|
||||||
|
if value_func != None:
|
||||||
|
self._draw_values(value_func, values_static)
|
||||||
|
self.surface.blit(self.path_overlay, (0, 0))
|
||||||
if self.draw_entities:
|
if self.draw_entities:
|
||||||
self._draw_entities()
|
self._draw_entities()
|
||||||
else:
|
else:
|
||||||
@ -440,9 +445,6 @@ class ColumbusEnv(gym.Env):
|
|||||||
self._rendered = True
|
self._rendered = True
|
||||||
if mode == 'human' and dont_show:
|
if mode == 'human' and dont_show:
|
||||||
return
|
return
|
||||||
if value_func != None:
|
|
||||||
self._draw_values(value_func, values_static)
|
|
||||||
self.screen.blit(self.path_overlay, (0, 0))
|
|
||||||
self.screen.blit(self.surface, (0, 0))
|
self.screen.blit(self.surface, (0, 0))
|
||||||
self._draw_observable(forceDraw=mode != 'human')
|
self._draw_observable(forceDraw=mode != 'human')
|
||||||
self._draw_joystick(forceDraw=mode != 'human')
|
self._draw_joystick(forceDraw=mode != 'human')
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import torch as th
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
@ -69,11 +70,16 @@ def chooseEnv():
|
|||||||
return Env(fps=30)
|
return Env(fps=30)
|
||||||
|
|
||||||
|
|
||||||
|
def value_func(obs):
|
||||||
|
return th.rand(obs.shape[0])-0.5
|
||||||
|
|
||||||
|
|
||||||
def playEnv(env):
|
def playEnv(env):
|
||||||
done = False
|
done = False
|
||||||
env.reset()
|
env.reset()
|
||||||
while not done:
|
while not done:
|
||||||
t1 = time()
|
t1 = time()
|
||||||
|
# env.render(value_func=value_func)
|
||||||
env.render()
|
env.render()
|
||||||
pos = (0.5, 0.5)
|
pos = (0.5, 0.5)
|
||||||
pos = pygame.mouse.get_pos()
|
pos = pygame.mouse.get_pos()
|
||||||
|
Loading…
Reference in New Issue
Block a user