Add manual controll override and pause functionality
This commit is contained in:
parent
60892ee145
commit
4f3c8bec8b
@ -6,7 +6,6 @@ import pygame
|
|||||||
import random as random_dont_use
|
import random as random_dont_use
|
||||||
from os import urandom
|
from os import urandom
|
||||||
import math
|
import math
|
||||||
from time import time as get_time
|
|
||||||
from columbus import entities, observables
|
from columbus import entities, observables
|
||||||
|
|
||||||
|
|
||||||
@ -45,11 +44,14 @@ class ColumbusEnv(gym.Env):
|
|||||||
self.void_barrier = True
|
self.void_barrier = True
|
||||||
self.void_damage = 100
|
self.void_damage = 100
|
||||||
|
|
||||||
|
self.paused = False
|
||||||
|
self.keypress_timeout = 0
|
||||||
self.rng = random_dont_use.Random()
|
self.rng = random_dont_use.Random()
|
||||||
self._seed(self.env_seed)
|
self._seed(self.env_seed)
|
||||||
self.reset()
|
|
||||||
|
|
||||||
self.observation_space = self.observable.get_observation_space()
|
@property
|
||||||
|
def observation_space(self):
|
||||||
|
return self.observable.get_observation_space()
|
||||||
|
|
||||||
def _seed(self, seed):
|
def _seed(self, seed):
|
||||||
if seed == None:
|
if seed == None:
|
||||||
@ -111,10 +113,14 @@ class ColumbusEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
inp = action[0], action[1]
|
inp = action[0], action[1]
|
||||||
|
if self._disturb_next:
|
||||||
|
inp = self._disturb_next
|
||||||
|
self._disturb_next = False
|
||||||
if self.limit_inp_to_unit_circle:
|
if self.limit_inp_to_unit_circle:
|
||||||
inp = self._limit_to_unit_circle(((inp[0]-0.5)*2, (inp[1]-0.5)*2))
|
inp = self._limit_to_unit_circle(((inp[0]-0.5)*2, (inp[1]-0.5)*2))
|
||||||
inp = (inp[0]+1)/2, (inp[1]+1)/2
|
inp = (inp[0]+1)/2, (inp[1]+1)/2
|
||||||
self.inp = inp
|
self.inp = inp
|
||||||
|
if not self.paused:
|
||||||
self._step_timers()
|
self._step_timers()
|
||||||
self._step_entities()
|
self._step_entities()
|
||||||
observation = self.observable.get_observation()
|
observation = self.observable.get_observation()
|
||||||
@ -168,8 +174,8 @@ class ColumbusEnv(gym.Env):
|
|||||||
pygame.init()
|
pygame.init()
|
||||||
self._seed(self.env_seed)
|
self._seed(self.env_seed)
|
||||||
self._rendered = False
|
self._rendered = False
|
||||||
|
self._disturb_next = False
|
||||||
self.inp = (0.5, 0.5)
|
self.inp = (0.5, 0.5)
|
||||||
self.keypress_timeout = 0
|
|
||||||
# will get rescaled acording to fps (=reward per second)
|
# will get rescaled acording to fps (=reward per second)
|
||||||
self.new_reward = 0
|
self.new_reward = 0
|
||||||
self.new_abs_reward = 0 # will not get rescaled. should be used for one-time rewards
|
self.new_abs_reward = 0 # will not get rescaled. should be used for one-time rewards
|
||||||
@ -193,26 +199,47 @@ class ColumbusEnv(gym.Env):
|
|||||||
def _draw_joystick(self, forceDraw=False):
|
def _draw_joystick(self, forceDraw=False):
|
||||||
if (self.draw_joystick or forceDraw) and self.visible:
|
if (self.draw_joystick or forceDraw) and self.visible:
|
||||||
x, y = self.inp
|
x, y = self.inp
|
||||||
pygame.draw.circle(self.screen, (100, 100, 100), (50 +
|
bigcol = (100, 100, 100)
|
||||||
|
smolcol = (100, 100, 100)
|
||||||
|
if self._disturb_next:
|
||||||
|
smolcol = (255, 255, 255)
|
||||||
|
pygame.draw.circle(self.screen, bigcol, (50 +
|
||||||
self.joystick_offset[0], 50+self.joystick_offset[1]), 50, width=1)
|
self.joystick_offset[0], 50+self.joystick_offset[1]), 50, width=1)
|
||||||
pygame.draw.circle(self.screen, (100, 100, 100), (20+int(60*x) +
|
pygame.draw.circle(self.screen, smolcol, (20+int(60*x) +
|
||||||
self.joystick_offset[0], 20+int(60*y)+self.joystick_offset[1]), 20, width=0)
|
self.joystick_offset[0], 20+int(60*y)+self.joystick_offset[1]), 20, width=0)
|
||||||
|
|
||||||
def render(self, mode='human', dont_show=False):
|
def _handle_user_input(self):
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
pass
|
pass
|
||||||
keys = pygame.key.get_pressed()
|
keys = pygame.key.get_pressed()
|
||||||
if self.keypress_timeout == 0:
|
if self.keypress_timeout == 0:
|
||||||
self.keypress_timeout = int(self.fps/2)
|
self.keypress_timeout = int(self.fps/5)
|
||||||
if keys[pygame.K_m]:
|
if keys[pygame.K_m]:
|
||||||
self.draw_entities = not self.draw_entities
|
self.draw_entities = not self.draw_entities
|
||||||
elif keys[pygame.K_r]:
|
elif keys[pygame.K_r]:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
elif keys[pygame.K_p]:
|
||||||
|
self.paused = not self.paused
|
||||||
else:
|
else:
|
||||||
self.keypress_timeout = 0
|
self.keypress_timeout = 0
|
||||||
else:
|
else:
|
||||||
self.keypress_timeout -= 1
|
self.keypress_timeout -= 1
|
||||||
|
|
||||||
|
# keys, that can be hold down to continously trigger them
|
||||||
|
if keys[pygame.K_q]:
|
||||||
|
self._disturb_next = (
|
||||||
|
random_dont_use.random(), random_dont_use.random())
|
||||||
|
elif keys[pygame.K_w]:
|
||||||
|
self._disturb_next = (0.5, 0.0)
|
||||||
|
elif keys[pygame.K_a]:
|
||||||
|
self._disturb_next = (0.0, 0.5)
|
||||||
|
elif keys[pygame.K_s]:
|
||||||
|
self._disturb_next = (0.5, 1.0)
|
||||||
|
elif keys[pygame.K_d]:
|
||||||
|
self._disturb_next = (1.0, 0.5)
|
||||||
|
|
||||||
|
def render(self, mode='human', dont_show=False):
|
||||||
|
self._handle_user_input()
|
||||||
self.visible = self.visible or not dont_show
|
self.visible = self.visible or not dont_show
|
||||||
self._ensure_surface()
|
self._ensure_surface()
|
||||||
pygame.draw.rect(self.surface, (0, 0, 0),
|
pygame.draw.rect(self.surface, (0, 0, 0),
|
||||||
@ -363,11 +390,12 @@ class ColumbusJustState(ColumbusEnv):
|
|||||||
|
|
||||||
|
|
||||||
class ColumbusStateWithBarriers(ColumbusEnv):
|
class ColumbusStateWithBarriers(ColumbusEnv):
|
||||||
def __init__(self, observable=observables.StateObservable(coordsAgent=True, speedAgent=False, coordsRelativeToAgent=False, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=False, include_rand=True), fps=30, env_seed=3.141):
|
def __init__(self, observable=observables.StateObservable(coordsAgent=True, speedAgent=False, coordsRelativeToAgent=False, coordsRewards=True, rewardsWhitelist=None, coordsEnemys=True, enemysWhitelist=None, enemysNoBarriers=True, rewardsTimeouts=False, include_rand=True), fps=30, env_seed=3.141, num_chasers=1):
|
||||||
super(ColumbusStateWithBarriers, self).__init__(
|
super(ColumbusStateWithBarriers, self).__init__(
|
||||||
observable=observable, fps=fps, env_seed=env_seed)
|
observable=observable, fps=fps, env_seed=env_seed)
|
||||||
self.aux_reward_max = 0.01
|
self.aux_reward_max = 10
|
||||||
self.start_pos = (0.5, 0.5)
|
self.start_pos = (0.5, 0.5)
|
||||||
|
self.num_chasers = num_chasers
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.agent.pos = self.start_pos
|
self.agent.pos = self.start_pos
|
||||||
@ -375,7 +403,7 @@ class ColumbusStateWithBarriers(ColumbusEnv):
|
|||||||
enemy = entities.CircleBarrier(self)
|
enemy = entities.CircleBarrier(self)
|
||||||
enemy.radius = self.random()*25+75
|
enemy.radius = self.random()*25+75
|
||||||
self.entities.append(enemy)
|
self.entities.append(enemy)
|
||||||
for i in range(1):
|
for i in range(self.num_chasers):
|
||||||
enemy = entities.FlyingChaser(self)
|
enemy = entities.FlyingChaser(self)
|
||||||
enemy.chase_acc = 0.55 # *0.6+0.5
|
enemy.chase_acc = 0.55 # *0.6+0.5
|
||||||
self.entities.append(enemy)
|
self.entities.append(enemy)
|
||||||
@ -385,8 +413,14 @@ class ColumbusStateWithBarriers(ColumbusEnv):
|
|||||||
self.entities.append(reward)
|
self.entities.append(reward)
|
||||||
|
|
||||||
|
|
||||||
###
|
class ColumbusTrivialRay(ColumbusStateWithBarriers):
|
||||||
|
def __init__(self, observable=observables.RayObservable(num_rays=8, ray_len=512), hide_map=False, fps=30):
|
||||||
|
super(ColumbusTrivialRay, self).__init__(
|
||||||
|
observable=observable, fps=fps, num_chasers=0)
|
||||||
|
self.draw_entities = not hide_map
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
register(
|
register(
|
||||||
id='ColumbusTestCnn-v0',
|
id='ColumbusTestCnn-v0',
|
||||||
entry_point=ColumbusTest3_1,
|
entry_point=ColumbusTest3_1,
|
||||||
@ -428,3 +462,9 @@ register(
|
|||||||
entry_point=ColumbusStateWithBarriers,
|
entry_point=ColumbusStateWithBarriers,
|
||||||
max_episode_steps=30*60*2,
|
max_episode_steps=30*60*2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='ColumbusTrivialRay-v0',
|
||||||
|
entry_point=ColumbusTrivialRay,
|
||||||
|
max_episode_steps=30*60*2,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user