From fd6edb02f716fa7d40468101797de231adc20c00 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 19 Jun 2022 15:46:03 +0200 Subject: [PATCH] Squashed 'subtrees/columbus/' content from commit 12cac4d git-subtree-dir: subtrees/columbus git-subtree-split: 12cac4db9912238e4daf6fb259666a362329c6a8 --- .gitignore | 3 + entities.py | 195 +++++++++++++++++++++++++++++++++++++++++++++++ env.py | 201 +++++++++++++++++++++++++++++++++++++++++++++++++ humanPlayer.py | 43 +++++++++++ img_README.png | Bin 0 -> 11024 bytes observables.py | 69 +++++++++++++++++ 6 files changed, 511 insertions(+) create mode 100644 .gitignore create mode 100644 entities.py create mode 100644 env.py create mode 100644 humanPlayer.py create mode 100644 img_README.png create mode 100644 observables.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3d46b6b --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pychache__ +*.pyc +*.pyo diff --git a/entities.py b/entities.py new file mode 100644 index 0000000..c03c7ca --- /dev/null +++ b/entities.py @@ -0,0 +1,195 @@ +import pygame +import math + + +class Entity(object): + def __init__(self, env): + self.env = env + self.pos = (env.random(), env.random()) + self.speed = (0, 0) + self.acc = (0, 0) + self.drag = 0 + self.radius = 10 + self.col = (255, 255, 255) + self.shape = 'circle' + + def physics_step(self): + x, y = self.pos + vx, vy = self.speed + ax, ay = self.acc + vx, vy = vx+ax*self.env.acc_fac, vy+ay*self.env.acc_fac + x, y = x+vx*self.env.speed_fac, y+vy*self.env.speed_fac + if x > 1 or x < 0: + x = min(max(x, 0), 1) + vx = 0 + if y > 1 or y < 0: + y = min(max(y, 0), 1) + vy = 0 + self.speed = vx/(1+self.drag), vy/(1+self.drag) + self.pos = x, y + + def controll_step(self): + pass + + def step(self): + self.controll_step() + self.physics_step() + + def draw(self): + x, y = self.pos + pygame.draw.circle(self.env.surface, self.col, + (x*self.env.width, y*self.env.height), self.radius, width=0) + + def on_collision(self, other): + pass + + def kill(self): + self.env.kill_entity(self) + + +class Agent(Entity): + def __init__(self, env): + super(Agent, self).__init__(env) + self.pos = (0.5, 0.5) + self.col = (0, 0, 255) + self.drag = self.env.agent_drag + self.controll_type = self.env.controll_type + + def controll_step(self): + self._read_input() + self.env.check_collisions_for(self) + + def _read_input(self): + if self.controll_type == 'SPEED': + self.speed = self.env.inp[0] - 0.5, self.env.inp[1] - 0.5 + elif self.controll_type == 'ACC': + self.acc = self.env.inp[0] - 0.5, self.env.inp[1] - 0.5 + else: + raise Exception('Unsupported controll_type') + + +class Enemy(Entity): + def __init__(self, env): + super(Enemy, self).__init__(env) + self.col = (255, 0, 0) + self.damage = 10 + + def on_collision(self, other): + if isinstance(other, Agent): + self.env.new_reward -= self.damage + + +class Barrier(Enemy): + def __init__(self, env): + super(Barrier, self).__init__(env) + + +class CircleBarrier(Barrier): + def __init__(self, env): + super(CircleBarrier, self).__init__(env) + + +class Chaser(Enemy): + def __init__(self, env): + super(Chaser, self).__init__(env) + self.target = self.env.agent + self.arrow_fak = 100 + self.lookahead = 0 + + def _get_arrow(self): + tx, ty = self.target.pos + x, y = self.pos + fx, fy = x + self.speed[0]*self.lookahead*self.env.speed_fac, y + \ + self.speed[1]*self.lookahead*self.env.speed_fac + dx, dy = (tx-fx)*self.arrow_fak, (ty-fy)*self.arrow_fak + return self.env._limit_to_unit_circle((dx, dy)) + + +class WalkingChaser(Chaser): + def __init__(self, env): + super(WalkingChaser, self).__init__(env) + self.col = (255, 0, 0) + self.chase_speed = 0.45 + + def controll_step(self): + arrow = self._get_arrow() + self.speed = arrow[0] * self.chase_speed, arrow[1] * self.chase_speed + + +class FlyingChaser(Chaser): + def __init__(self, env): + super(FlyingChaser, self).__init__(env) + self.col = (255, 0, 0) + self.chase_acc = 0.5 + self.arrow_fak = 5 + self.lookahead = 8 + env.random()*2 + + def controll_step(self): + arrow = self._get_arrow() + self.acc = arrow[0] * self.chase_acc, arrow[1] * self.chase_acc + + +class Reward(Entity): + def __init__(self, env): + super(Reward, self).__init__(env) + self.col = (0, 255, 0) + self.avaible = True + self.enforce_not_on_barrier = False + self.reward = 1 + + def on_collision(self, other): + if isinstance(other, Agent): + self.on_collect() + elif isinstance(other, Barrier): + self.on_barrier_collision() + + def on_collect(self): + self.env.new_reward += self.reward + + def on_barrier_collision(self): + if self.enforce_not_on_barrier: + self.pos = (self.env.random(), self.env.random()) + self.env.check_collisions_for(self) + + +class OnceReward(Reward): + def __init__(self, env): + super(OnceReward, self).__init__(env) + self.reward = 100 + + def on_collect(self): + self.env.new_abs_reward += self.reward + self.kill() + + +class TeleportingReward(OnceReward): + def __init__(self, env): + super(TeleportingReward, self).__init__(env) + self.enforce_not_on_barrier = True + self.env.check_collisions_for(self) + + def on_collect(self): + self.env.new_abs_reward += self.reward + self.pos = (self.env.random(), self.env.random()) + self.env.check_collisions_for(self) + + +class TimeoutReward(OnceReward): + def __init__(self, env): + super(TimeoutReward, self).__init__(env) + self.enforce_not_on_barrier = True + self.env.check_collisions_for(self) + self.timeout = 10 + + def set_avaible(self, value): + self.avaible = value + if self.avaible: + self.col = (0, 255, 0) + else: + self.col = (50, 100, 50) + + def on_collect(self): + if self.avaible: + self.env.new_abs_reward += self.reward + self.set_avaible(False) + self.env.timers.append((self.timeout, self.set_avaible, True)) diff --git a/env.py b/env.py new file mode 100644 index 0000000..d189a4e --- /dev/null +++ b/env.py @@ -0,0 +1,201 @@ +import gym +from gym import spaces +import numpy as np +import pygame +import random as random_dont_use +import math +import entities +import observables + + +class ColumbusEnv(gym.Env): + metadata = {'render.modes': ['human']} + + def __init__(self, observable=observables.Observable(), fps=60, env_seed=3.1): + super(ColumbusEnv, self).__init__() + self.action_space = spaces.Box( + low=0, high=1, shape=(2,), dtype=np.float32) + observable._set_env(self) + self.observable = observable + self.observation_space = self.observable.get_observation_space() + self.title = 'Untitled' + self.fps = fps + self.env_seed = env_seed + self.joystick_offset = (10, 10) + self.surface = None + self.screen = None + self.width = 720 + self.height = 720 + self.speed_fac = 0.01/fps*60 + self.acc_fac = 0.03/fps*60 + self.agent_drag = 0 # 0.01 is a good value + self.controll_type = 'SPEED' # one of SPEED, ACC + self.limit_inp_to_unit_circle = True + self.aux_reward_max = 0 # 0 = off + self.aux_reward_discretize = 0 # 0 = dont discretize + self.draw_observable = True + self.draw_joystick = True + + self.rng = random_dont_use.Random() + self.reset() + + def _seed(self, seed): + self.rng.seed(seed) + + def random(self): + return self.rng.random() + + def _ensure_surface(self): + if not self.surface: + self.surface = pygame.Surface((self.width, self.height)) + self.screen = pygame.display.set_mode((self.width, self.height)) + pygame.display.set_caption(self.title) + + def _limit_to_unit_circle(self, coords): + l_sq = coords[0]**2 + coords[1]**2 + if l_sq > 1: + l = math.sqrt(l_sq) + coords = coords[0] / l, coords[1] / l + return coords + + def _step_entities(self): + for entity in self.entities: + entity.step() + + def _step_timers(self): + new_timers = [] + for time_left, func, arg in self.timers: + time_left -= 1/self.fps + if time_left < 0: + func(arg) + else: + new_timers.append((time_left, func, arg)) + self.timers = new_timers + + def sq_dist(self, entity1, entity2): + return (entity1.pos[0] - entity2.pos[0])**2 + (entity1.pos[1] - entity2.pos[1])**2 + + def dist(self, entity1, entity2): + return math.sqrt(self._sq_dist(entity1, entity2)) + + def _get_aux_reward(self): + aux_reward = 0 + for entity in self.entities: + if isinstance(entity, entities.Reward): + if entity.avaible: + reward = self.aux_reward_max / \ + (1 + self.sq_dist(entity, self.agent)) + + if self.aux_reward_discretize: + reward = int(reward*self.aux_reward_discretize*2) / \ + self.aux_reward_discretize / 2 + + aux_reward += reward + return aux_reward + + def step(self, action): + inp = action[0], action[1] + if self.limit_inp_to_unit_circle: + inp = self._limit_to_unit_circle(((inp[0]-0.5)*2, (inp[1]-0.5)*2)) + inp = (inp[0]+1)/2, (inp[1]+1)/2 + self.inp = inp + self._step_timers() + self._step_entities() + observation = self.observable.get_observation() + reward, self.new_reward, self.new_abs_reward = self.new_reward / \ + self.fps + self.new_abs_reward, 0, 0 + self.score += reward # aux_reward does not count towards the score + if self.aux_reward_max: + reward += self._get_aux_reward() + return observation, reward, 0, self.score + return observation, reward, done, info + + def check_collisions_for(self, entity): + for other in self.entities: + if other != entity: + if self._check_collision_between(entity, other): + entity.on_collision(other) + other.on_collision(entity) + + def _check_collision_between(self, e1, e2): + shapes = [e1.shape, e2.shape] + shapes.sort() + if shapes == ['circle', 'circle']: + sq_dist = ((e1.pos[0]-e2.pos[0])*self.width) ** 2 \ + + ((e1.pos[1]-e2.pos[1])*self.height)**2 + return sq_dist < (e1.radius + e2.radius)**2 + else: + raise Exception( + 'Checking for collision between unsupported shapes: '+str(shapes)) + + def kill_entity(self, target): + newEntities = [] + for entity in self.entities: + if target != entity: + newEntities.append(entity) + else: + del target + break + self.entities = newEntities + + def setup(self): + for i in range(18): + enemy = entities.CircleBarrier(self) + enemy.radius = self.random()*40+50 + self.entities.append(enemy) + for i in range(3): + enemy = entities.FlyingChaser(self) + enemy.chase_acc = self.random()*0.4*0.3 # *0.6+0.5 + self.entities.append(enemy) + for i in range(0): + reward = entities.TimeoutReward(self) + self.entities.append(reward) + for i in range(1): + reward = entities.TeleportingReward(self) + self.entities.append(reward) + + def reset(self): + pygame.init() + self.inp = (0.5, 0.5) + # will get rescaled acording to fps (=reward per second) + self.new_reward = 0 + self.new_abs_reward = 0 # will not get rescaled. should be used for one-time rewards + self.score = 0 + self.entities = [] + self.timers = [] + self.agent = entities.Agent(self) + self.setup() + self.entities.append(self.agent) # add it last, will be drawn on top + self._seed(self.env_seed) + return 0 + return observation # reward, done, info can't be included + + def _draw_entities(self): + for entity in self.entities: + entity.draw() + + def _draw_observable(self, forceDraw=False): + if self.draw_observable or forceDraw: + self.observable.draw() + + def _draw_joystick(self, forceDraw=False): + if self.draw_joystick: + x, y = self.inp + pygame.draw.circle(self.screen, (100, 100, 100), (50 + + self.joystick_offset[0], 50+self.joystick_offset[1]), 50, width=1) + pygame.draw.circle(self.screen, (100, 100, 100), (20+int(60*x) + + self.joystick_offset[0], 20+int(60*y)+self.joystick_offset[1]), 20, width=0) + + def render(self, mode='human'): + self._ensure_surface() + pygame.draw.rect(self.surface, (0, 0, 0), + pygame.Rect(0, 0, self.width, self.height)) + self._draw_entities() + self.screen.blit(self.surface, (0, 0)) + self._draw_observable() + self._draw_joystick() + pygame.display.update() + + def close(self): + pygame.display.quit() + pygame.quit() diff --git a/humanPlayer.py b/humanPlayer.py new file mode 100644 index 0000000..99bcae1 --- /dev/null +++ b/humanPlayer.py @@ -0,0 +1,43 @@ +from time import sleep, time +from env import ColumbusEnv +import numpy as np +import pygame + +from observables import Observable, CnnObservable + + +def main(): + env = ColumbusEnv(fps=60, observable=CnnObservable()) + playEnv(env) + env.close() + + +def playEnv(env): + env.reset() + done = False + while not done: + t1 = time() + env.render() + pos = (0.5, 0.5) + for event in pygame.event.get(): + pass + # if event.type == pygame.MOUSEBUTTONDOWN: + # pos = pygame.mouse.get_pos() + # print(pos) + pos = pygame.mouse.get_pos() + pos = (min(max((pos[0]-env.joystick_offset[0]-20)/60, 0), 1), + min(max((pos[1]-env.joystick_offset[1]-20)/60, 0), 1)) + obs, rew, done, info = env.step(np.array(pos, dtype=np.float32)) + print('Reward: '+str(rew)) + print('Score: '+str(info)) + t2 = time() + dt = t2 - t1 + delay = (1/env.fps - dt) + if delay < 0: + print("[!] Can't keep framerate!") + else: + sleep(delay) + + +if __name__ == '__main__': + main() diff --git a/img_README.png b/img_README.png new file mode 100644 index 0000000000000000000000000000000000000000..6f019e593026cd35ad54adb25526d90023dec639 GIT binary patch literal 11024 zcmZX4cT|(h^Y?}zO(RVaLI**;6ahh`_oAQ#K}0}0qM<2GIsv6A%?3yfC>W)KRA~x= z3K%h=2~tFQ?-JU(x!-&5?~j*rI7#-Iot>SX@|oEu26IUd#e`-80HDt6YhMO{h7A9S z40LcrQ~4?Ze$aYpo;PKHe?bh6(eO8yw~o2DiN_7^yLO&Vz}ely%}K(`!PCjf-Rq`@ z_ad@U6%zeLCDQVAvh%*>;Vxi$%gqUxcnC-<2wd=S7m$&Xkr9xVQ9L81D0fysOUK0Y zUeYra00h8!Z4Fbu^!cAQ1mFJPZ_B%N4L#vpmR*`J!>k>5kdeNWH`>wqr#(k?9^t|@ z#9L&-IK=#}-K%@ zqY|;hYr}tUzP{Ibwkl+HH7vV7bMgAzu8K>K|JTuHSFQd1Hs<=*q8hfl83Cv! z-ra5OePt7J&X8!4#Q?zC+}3jFPSrl;S=#1Xl@K8dfYsgQ zD7KHXzkjKq!1g(AGG844x3f-YNaP{FNJv7sg&`6&$t)hINTCC`@r6AeEWokW`7P(H zpJHjU^r-jQR7+DNsO?f_^Y`lIhl0wNo~s{XenNafSKW4%qs8=~O$KkMo$t<8|8j9h z{C;3p`Df>8`75`=DOSqGsIkX32LW!d_V5g8NrJ0xvZ2R1lmmd^SxJIr#Qb!YND%^@ zn-L`>uyPC95y*fBh}zEt)mig7P|1fM@1U<=r2g`}3rHEUSOnN{C`m zKGrCwwBX6?GtoT=H*;ov76~}7uj2Xa4)kK>fVwr&VsG@tfr)AO11QK|%jWYxcwa&Q zsJpF>Z12qdA$G=xhO~Lz8dbW{mL;*Xo=?AiGynT!9}Nvi)a#bur&*D+Tj`GqV{e+n}amlXM_UejqLvA1gL7`LiYa8C+_Z1-W6aD zesC)=MS`_vJ}dxzct4C~hGg?!a6|kL7T`puexr-$RY#Gyz+kNJ-~ivuu$)qe+oT#X zi5Y+t|LU>g7xo%%B0l_;Apo@8h(H5y_GDSqkD0ZGxm-GcOGj+O+sM?E7cQYS}`)2mWmkZktC zvPz(R0+*e|s^OHq)&1?B%A?qPlFv0rWZ~ztan*Y`eYYa{Zpc^~Np@FNjGUOUO9CYc zy?bPvjE2R2dJ6N5rC6oeH+cTdjoT|~S%;&VDR}pqpAMfZqe!7=wXC)oQ~=Jr)*^eT zHuTS*#fQeacJJtcNKoJ`yZKZoJiB1uov;s~*2griQ|OoP!!yn;pgc`TwLHABWBSlq zO}hD2FjPS4*0FbUuXSaaOv=iBBEbv;H(5^WD%7vlMn%6({nq!n;_sStJ`bCYc01bp zpur5|F|tIn86+qn|H2Qg(z6NEy50Axj^a1`&6ZCBm^Pd3F8=2sD<}3Y@$;wA19g>E zqT3$!N%=|Vh9f)6nl_tRdrycLhtKi+a@~r4`vW~~1wi!l>7AuluNt?;)Z8P`GNn{W zwz~J1hU+9G0KtbFXW>kh??knLdh3V~PAKxMFz0Lwi*j2Wo8?8w8{GU)nxi=oec z`z=1>??o9)h!4F*8JWxavlVc{Dm4ELIQ!v2@TXBaYh!jI55k3YZ@ZStFQZ)C!O2s>B-0P7X7?;s;(Nx{)k z{Ye_Q1A)q+DxiV2N8{tRGJhC-jSmg*<=ORmvrn4)!%p&0RorVN?Wgau!}t+8KhwKE zHTQ?!xsCrhZ|LOQpV{)UD(5yeC8fQU?()|j>rCr>1FAv=hxE=xxwvj}My1f+2E-_Rb}48 zQ7K|%J-J`Tc(w>!V8QwY&sjBoT!`!8nR>M+0mPB{Q+b7#CPnkc=v~)Vk@yTr)w7G7 zk8?jPSosSr}wv5C#@y)zQ6(hJb*Y+k>DrWQ&qYOy=yd4 z^=OF_-GkE#b||UH9F3pYXNqFIGRco}uxp8#>^)MM`_f8#^cTPW3RzVltrAP)-TB^D zeGQ$?oUwcmdU$<^|Ajo0bVmD0!Exd7BClZJ$+Y|MI{t3dt%8&Z)uN;7lJ96MuFSX< zT$+@+tbed&+)5s&?7N+Ey|^Q~_O(2{Y8^-(AO_r~^vHHnWIDZbL%%v})X%Rut)Y zp3Chx1jWQD;z!8LtW}z7+>!63<7?QmVszhtvvXU2m3W8E@XmVZervJKyu)gNY1T23 zJe9gd`5$lAWMif|th{g`rn@{?<6ZOnDXMYJ&9XPdha7&p2+#evGmoHPf+O~2-&nRe zim_Vn7fLDb?RN1`3Dd2;?rKw`8BhbiT*{m#I3Kf`-E?u;O6^sCs)uqxoEb{d6MYj_ zsH&(GQuz5=i*)Gj>g?*k!A5J6jf(0IjzkXW>)8Ac43>PlesP;@q-9QaTy^R`W|F~D z9b4aEA?Qb$T-!@LMl@Pi&3k_BHC^;;f9RA85&{M`CLYY~TU)vbr9(v?Ze4B6#wZr! zXuRr|w-)siK8}9+nSrQ6A&KVog7b!(TBb;YUE5i92KSKQz0MDzYNi+?(pdKlY0SDn zWwC>~&7io@w8^Q#w*u-r8jCwOba1Bm$ZYlxOLGqeqA~x27Yz$Hh3Xft>6#*2mTpEU zT9_J{?34@zSo1|fXcAa9E^KdVL?V@_?(TP+&G59y7U~2Qm&abgj_~o{3Z{7|f)64@MYSo>`b-zSBle1F4u>$qIW}|ONgRm`_ho^4$lYR07TbfW8eQ!%?$JA( zF1T8={h2$EA>%C|kXl@&8ur_)5R;{7DCC)HMQ^_vx9AtEpstd3k8<;O3eQ#r62umn zLpA#4zkN|}O>xZiv})zU$9D%3iWTBIT;;tHD6lp$-@kk!HC@MA=PH5tx$HsOxQ|ql z>+rX-bOlh-c-@RZoW9d-p^gN%4Jz=eM-_^6WUpxldUVCxoIhMrEUFQV4S7<`i{O(1 z>b}(Q`B~KY(bZLa=DQ=+mRgqQVD^lLs~{tFXK;jRJ%1h7SFKy1O?r-h2?Q|px2Fh-JOFRGnY+KcEf zRkv#W|LBB}=OJDF+zYyfb4q(1bAMj_m7&LB_J5ohs6eU-!YSPIW2dd(sz$uo8;f%& zVuy9ii6I!0dtQ}1;J@V2mc*>zbn$6!2N%wh_FrMxmbO;r|7M_$f#{kpoQkJd;LP)V zdOsuoPMZHmTRQ(j^>5k#n{eD=K=1m0a~QAX7ESukvIFqx_vnXxn~@2p2sGR6CpS)c z(R~R_k!)L23C#*&go4H;-S_MyY;=0@!pXMx7ge0k*K&uF@sY;|#cPslXa=5HU8}tCQw`3L0>g%jdjHC9%L1ng5(HW0FStwV zfz7a$|D@;mtLDpq{l~t)e6IZ?aANEgdP!+{d;#0y>!d}4&HGffG+nUFk=6XBdoJn_ zt=#e7bf41|5Zl*Q_a3Ze#(GQEvq4$7C1Nfk(h{I@PL4=jPZ+lhZOkQ1c0ZhiP`Izn zE;TQz=Y$3W-&4%h;zb2w+4 zFc*yPmBvD*BaFtuaA(NdCOiq!^cQZQuAwe9B4iw#hL6;!7L0E;N9NWPTTwMsecseY zlwc>|9=XNkU~9A>J;{jWj7_t%Zsj1nub(&t71Sh@Ly?B;JcT>jQfMk_7V8$NM>G>~ zGnnScWH#Mf&$|8CJ=aoA)IGnH!6g|sk~-~+;fGMH6E6`2yWfP$apYsRuC2J8B^rCQ z&%x-USpA!Cx%;DAWLf%eZl2VDW;`Sv*~Mq7o~2uRl@Y72sh?o?lK#lI#%9^pU;~U6 zjeWo&VdbZ|^s=Z5d}Mt8FX{LWW+7XiYl}|dXho5iZI9Lww6Ukgu<+^OnCUx?fm49T ziOB^*H&)WvrxiS+T2#w?`=jJHMnRRNA%00z@aomP%k;BBj20b+%f{4J#k4G-;;&OZi;F1xG%k2Wz1AE`@wbONX!4o9O| zF7rG;2zIIdSrUei{8FGV<9rFMUBqbVe26jVK;ft^0{YZ3S|VlwH6M8pQ&0YJL{r=G zM}g19z%J`Qb2F~pafHr=KmCtMmk5i+6*w(`HLZUH@U@S4RM;ViQ~l3nsXp#IqSkHl zpoW$u>mSBPZs(9%2TqF`|R6&L(*$X2IcGvJarQUorW`HAnl*9Uap2agN1Bagc5DdbPZyyiq0TS>Ap zpM~?gn=Zou%0hCk7Hz7hm1`QgKju_~y3K5<3g|>Q799>U+Yx0^N;py}*{h4p8IFNT9gM2Ca zXJZl#tC+JZ{cp*>i|OnE>(xC{XEjiXiB0)Lxu`pgSe)bi>q<$C=WRu>c@xvNA7WjNB{n@wbSX8VA2I*~B^3gd~ zfEo}1H>7GAGTS(tUC;nSF%*?izwtR z$(ZM>;WGK8g)vEwq6JcHmZ~Rp^|cg2()iTXXywu%sjXvWP)STmrm9>RJB*G4OMmz`8{8kVlXW#jkKAjD{LG9%PnFx28r zHhWc1zj~(io0>+aj{6EGs*n!^vy;vR#?b}5ZfX^4;wl97NKKCR$VeY;I=(_2G(m^3^QRdw+IG?$lu2a^dfoLlG3Qe z{c!_QxwCHwaA1=Mu1@WGuuQd65QMdSIGR)UVYKx&YO%sg#r03^ z+i`9MXg|p5?&y^7VS&@Lj50VsTA#^+%kTYFbiK)HPh^75(E2lC1tKg;g>E8J*vY)B zw$6@nSGVhW`@YP`{uU*Kxs&7B;PzB#<-Wx^j4bf>Xr?}fDh>&+HR1ac>snm0(U2o6 zjdEBLV+(VC#xxQYZgUU^Ll54?(!pH6{8&Co>aeA%;&CwYI(I<5E}`#9N@S^dHr?f# z6Hh1%*GPM~nHvAQ*gUjTlt~!>rW(F8)7r4&=m4Zk2)qa^Q6w1?Oy|NvSh# zrd)&JT_8%I`CId68%%0SgWFpd#|Sh|=pK|+)M4iN_>*E@*uiY($x0ee-@94>G4!Ja;SCJ3BtPm& zBmTsMtADO)v*Kk#BbX5JLUtFKJ4nG9*28jkJQ8`YprL)k*$V!(T9zqi%q#uBfy5F~ zHm}kK5}|>ky-SUb#4JMdZIB_0M%~rEHe-ufY(pP|WQi=vJ1fYmwYV4z4W(7!)ul4V`v;O`RzB?o@FCa0WKVK2erXa_L@LSN)p zb5dv}_O2G7EsYQMhx>Y~RQ7+8M*Gi*kyTc3pxWr= zQSlAux;A#F4(dVb_>a&)2L#{QySKQ;?&xp^YWzGkpbC;=mcBXhw=1}PrnjC*$+i@MWU$Qk7qP~ckuPW8)= zgH)?Cc#05Vn4xvu2H++{>r%4$y2m5X;KOP@6Vrb7Y@$s>5R>Q(coEo|$G z>#sGN-ziS@(!rdtKMniC!4PSPie)tgge!98lY!%S8eh^IIcKQeI)f5~3&YI)g3Cxi z7A8Dl>jtytx$&Q?OU&YGV8*7FTp1}>usENBNE(BGkkAQd*sZp;2@e1{Xj_E=)n8H^ zt`(a+Y~*_^-2iM;iw=lr_@d!7Cs#SJAV{DWhHdXJY-A+l55^&z>JOC(l~$n_`n@C> zXEdnI;gebzW<%Iqbv;3#KLM#LS;?@h=bjroP{c}2C~ycSlsY{~B2G;^)K6+Af}ev% z7_TG(9jH#gPY4;K!0j~r#9zj!rN;?hk&>+BLGP6NusRV5Pkk2y3#eDLF#o|%nExL3 z!X#H!5)e2(y!q2p1r5tc=(P?A@SApH}W& zi3JM8YGTU8>`+Z+nZp0uys03hDxj6S!sr47(jaiaNoVM;Yn7KnK>C$QuJc+29S9I5 zFpf#ZxClsY=E3dvypg`^x+QSjQWM~;8L%JHLY+^QAaHPrpR9erLD6twOWedIKM|i7 z*gGWdUI&_nb}&M(w*>Wotw(sOAXE^~AhG6ZXAL#%(XrB72@|*?>4nq}FHBMm0pXtS3VwlOjKwC;U{63De`eQInbrek9wa}d6ExsPm zsMr4d%qgENSj~zAmWjq=L|=Gjb0NyV9T|+^tb0) zziDC6O*@e8lwzwqwCT^;UMCb^h_`lV*Vj!y06>8#q1S@F)xK6K@=)vOZf1gX=ZTpV zrWuK(SS*>V>0;zzyv@R>zc!zNMt!$kt?GxWJ=N<{_Y>1L{oGrtV0>AE_*-9?qF>(s zzseNR^w+}Z?u=k*0j8N|?;i5^m~3{za8P~RBYb3BKp-ih!yKWvK!0cb+}Ut6 z(^B(;(x0o-NwFVO&Rm?((5{bo>lC80!6QzE0#Ud!}G|X~N|cAsdn3 zJwUDtX?<<87@DJZ%Ivk?cqYtRy~Ih3mX_+i_j5CO$1LNYS#!)DKW-DPZ@k>hK8j_q zy}vkA6@5F{Vt;FpXkIGsHCNeH^awuqy+$WXVbX6&D}6TeN9ZfJXPHIt5Mfb61Emoc zP#zz0hNJtAPL|Sd*_rwP_qEL|eB_9(dBwuKuK1g=X}7~}6+U}$zV%p6xq{BzWLO6u zH?i6bVTY5fvJ6D@-s~0pbcw%CsjY83E&RzltD!~N{{lwq)r9=pDfxw1d0XS3)_ML% z?Ibglq6~YhD_!$@CO*JY{PDTFFI;@s`KR;5-VD@)W@~T@qZiEjOk8F|r$*{f4JRx07=KaaG&OqgTs{djPgH4t?tMz}1Omg)a`x7u|)bnu3h9k~~D;$j%Fs{#%VcruFe!FYLV<%}|A?Z@QUv|@VLxg}} zDnIbDU`6n@%l-Rn7G{N;W2vp81ca}BMZmjY;hdeT@_>*~} zfa_y<&j}zE-hBM8g|0ClHWKVHIyy(7@j_xflj|H^B*#O^Wfm82z_k%}yO$I0$Wc<< zL<&BCN@B)E1`PhfXciY+8YYifJdVNBG2nZftlmuao&xP&7e9Z1B@ass0HjM%r#KE9aOOLq&l^zC7jCNC-P|hu2to`|VH8_ zP0W*UdTemL8XgGsuPXS?ySz(C79>n(rnu{wl2SL-HXrylK9IRPImtSWUf0HuDo)D3 ze54bokI*K!uU@b%7A8!aotA&a3bWfQ8Rb`(Z0>FEIH```lXR*+@8n=e=uCHbZQ&(4 z>yg??sl7nGZZ7H4&%V`7fpwg>^UE|dk64scYS&daz0h5~@YEtdQmb<;;k~j$;bkv9 zebuv+br!O>!*6-4N9@tW?x|6~c0z9(wFk_t8`~lg-|O;JrN@jg9nO6Xnnz;9c7MPN zDks*F*54t#?O$Yq0PE0)% zm>t53k?)V^B}J9e28>u=F{*TdWd&~3>0lk7W9>wEYNAa2XTQ%WG^}w;NUujuHG2@I zabAv0Y|q&&Wg*`uZKpkGkCm`|uJzPj@*`7_6<;1xJ1YP=j zcv=Du+pTR(*r_<6nj}_+pY5Ea973I^^;cT4*YT={q_vnQM zhGg|yo?%;N1BkP(*9gMIJf+)0fRzmu!{xn>9(KB%tN@xWcV$u=&bS5b2Y*jZ&e5*; zr)kxxy;byhqvODW5A2JopJUAwq!JcY#8U~WJrw|Idp)07z&1^yX8wPyV|vt-)s-h` zH&@Q{H_1t4o`AFd*Ccy+Iq+}lzkBL9Z*0-!gIG{IM$aAwkr4o=x-F=c7gatwyg_9! zFs3f?{Y~lozpXo*d;bPv&|&e)e-hN*c!s_S0XIk-${bLoa)B2Lw?2gCIDJjGJ-AAL z^WP>C?g)g!RdW>yRB@>S;AG)|1a$@2RwTtfJ09P0;DSsR&Sj%d4<4sJ1f+8RhLFU^ zSFm^vjf*CHekt%h4cJUbO!DZ%q2YS3W3W`nHlaM%>NNNIqZRVfv%M+he-#vI`wZRp z5HCv0T)WKaWg)eQQ-Lc+3cU&YPAg~U_bIN~oRq4-@D#=-Dk8sl(kyP}Oslq?Buy@r zPUKA(8?D@^U(>jgYs{xQ+Dd96YGjJEWWZ{!+-`46VarfpDuCHxf;DJ9Zrd{$frscGrF;w zoRVV!kDlZ6l-xhv;o{bH92%ZLKl-@}lMbj?q8rm9l}GF8pJ9Uj`}v|Gp94*$9FJSg z;UfOP0f%@-Z0&d+cT_2^O&2w5Y$M;dqS1T!$$ra^)Tk3yzLdRB0mYuoK#Y8R>N44Y%{b7YX#Ax z#`S8j562t zGpda;`H{Jp_Ee+3mwV}ZFYjn;kC^=L`3?_IEvbX4I-Cd7fx3O^0Y~0dPOw07!ssyz zhw(n?p4on0u1>o!@U(2xt1u0}u!ZsoSkTOske?jrFb3wLK^W(US!1)T67zjDo|y`a z+Kz(v3$e<>nsn<79z6Ly^Nh2}9crLY)D)9|K5JGufQDE`9-I``#w2jaBs0ub>N2{l z9-NAE<1e@}`@C0{P_OWny~oW7<+92MHkJ#z`O_uLHb&4DJ}I8sHLK}BlwJR9Q5+bOLU7*hWG2G%?h4Uf68FK|2 zZ$XR@(vbKKt4$F5(nr6``o2Pm0^xv_aw{#sXI|8h82lx37Q_mDU=nfD=GyMokf0`g zIB>z-*yyFS%~7&uA%80q&RCmCeB-EF!7=&@_1OezB+p=d zyJwLIVJwYQc=UT6F5|Y!KgE3X;HCuDuhjatbiEj%&~?I29Kosc8#p|la_~PxIAE^Y zdIa)jmUf@DJ5)?pvGk?Ak4EYNUwG_l0GCoAjQnc>jAw>=Nh#M4VZ~D51qjgnZi0_1 zsw!3FCTqIzw4!m3);O@jGk0RCZ{|dPcPNR==`HnVBb$l0IdmX6FBXp`T_ZZIoFqha zSI4S=7t;zOWz)gpnN_N=#cLKX=xp~oZb*wVReu!A&gcbj#!igyGGVF-ImrrazZ~`J za(dSjo~j1}np9cy)6*M5p93JI2Tg{pQ!Wy#a!1L^cCfNvDaz|P`@!c346#=o6W*)D zdCpV2aZM;yYmeu~nv_pWa?Uk&pvgqo)uzv$JnH~k4z_|~JqdA~U9So-UXRjfftCEs zgSoD$;(bUAnsi?oKM&Tppsof(5McUJ%d4#q@7`)=>{jS#aB?yQ7B0AL7{_oKoz1dX z+k;a+Fkq@X8Piy*L*Re_=yx2ImJXJ(i?Ggr5=;i);+<~?sYd;qF8;ecxf?@ShBwKD z)NpJ*`Um-fno)ZEst9n8|6CyFHru6U@*Y-|!zL-#c@e&?3eM|X(k{}p4gY`W Cr0IMB literal 0 HcmV?d00001 diff --git a/observables.py b/observables.py new file mode 100644 index 0000000..134dd37 --- /dev/null +++ b/observables.py @@ -0,0 +1,69 @@ +from gym import spaces +import numpy as np +import pygame + + +class Observable(): + def __init__(self): + self.obs = None + pass + + def get_observation_space(): + print("[!] Using dummyObservable. Env won't output anything") + return spaces.Box(low=0, high=255, + shape=(1,), dtype=np.uint8) + + +class CnnObservable(Observable): + def __init__(self, in_width=256, in_height=256, out_width=32, out_height=32, draw_width=128, draw_height=128, smooth_scaling=True): + super(CnnObservable, self).__init__() + self.in_width = in_width + self.in_height = in_height + self.out_width = out_width + self.out_height = out_height + self.draw_width = draw_width + self.draw_height = draw_height + if smooth_scaling: + self.scaler = pygame.transform.smoothscale + else: + self.scaler = pygame.transform.scale + + def _set_env(self, env): + self.env = env + + def get_observation_space(self): + return spaces.Box(low=0, high=255, + shape=(self.out_width, self.out_height), dtype=np.uint8) + + def get_observation(self): + x, y = self.env.agent.pos[0]*self.env.width - self.in_width / \ + 2, self.env.agent.pos[1]*self.env.height - self.in_height/2 + w, h = self.in_width, self.in_height + cx, cy = _clip(x, 0, self.env.width), _clip( + y, 0, self.env.height) + cw, ch = _clip(w, 0, self.env.width - cx), _clip(h, + 0, self.env.height - cy) + rect = pygame.Rect(cx, cy, cw, ch) + snap = self.env.surface.subsurface(rect) + self.snap = pygame.Surface((self.in_width, self.in_height)) + pygame.draw.rect(self.snap, (50, 50, 50), + pygame.Rect(0, 0, self.in_width, self.in_height)) + self.snap.blit(snap, (cx - x, cy - y)) + self.obs = self.scaler( + self.snap, (self.out_width, self.out_height)) + return self.obs + + def draw(self): + if not self.obs: + self.get_observation() + big = pygame.transform.scale( + self.obs, (self.draw_width, self.draw_height)) + x, y = self.env.width - self.draw_width - 10, 10 + pygame.draw.rect(self.env.screen, (50, 50, 50), + pygame.Rect(x - 1, y - 1, self.draw_width + 2, self.draw_height + 2)) + self.env.screen.blit( + big, (x, y)) + + +def _clip(num, lower, upper): + return min(max(num, lower), upper)