# PriorConditionedAnnealing/test.py
#
# 164 lines
# 4.5 KiB
# Python
# Raw Normal View History

import torch as th
from time import sleep, time
import numpy as np
import pygame
import yaml
from columbus import env
from columbus.observables import Observable, CnnObservable
import colorednoise as cn
from priorConditionedAnnealing import pca
def main():
    """Run the interactive play session.

    Asks the user for an action source and an environment, then plays
    episodes back to back, waiting for Enter between episodes.  The loop
    has no exit condition of its own: it ends only when an exception
    (for example ``KeyboardInterrupt`` or ``EOFError`` at a prompt)
    propagates out.  The environment is closed on the way out.
    """
    agent_func = choosePlayType()
    # Named `game` so the imported `env` module is not shadowed locally.
    game = chooseEnv()
    try:
        while True:
            playEnv(game, agent_func)
            input('<again?>')
    finally:
        # A close() placed after the infinite loop would be unreachable;
        # `finally` makes the cleanup run on every exit path.
        game.close()
def getAvaibleEnvs():
    """Yield the attributes of the ``env`` module that look like environments.

    Discovery is purely name-based: every attribute whose name starts with
    ``'Columbus'`` is yielded, except the one named ``'ColumbusEnv'``.
    Attributes are produced in the order ``dir(env)`` reports them.
    """
    # Name-based discovery is admittedly a bit of a hack, but it means new
    # environments show up here without having to be registered anywhere.
    yield from (
        getattr(env, attr_name)
        for attr_name in dir(env)
        if attr_name != 'ColumbusEnv' and attr_name.startswith('Columbus')
    )
def loadConfigDefinedEnv(EnvClass):
    """Interactively build ``EnvClass`` from a YAML config file.

    The user is asked for the path of a (possibly multi-document) YAML
    file.  Documents that are empty, have no ``'name'`` key, or are named
    ``'SLURM'`` are ignored; the remaining ones are listed and the user
    picks one by index (empty input selects ``0``).

    The keyword arguments for the environment are looked up inside the
    chosen document under the dotted path ``params.task.env_args``.  If a
    key along that path cannot be resolved, the user is asked for another
    dotted path and the lookup restarts from the document root.

    Args:
        EnvClass: Callable invoked as ``EnvClass(fps=30, **env_args)``.

    Returns:
        Whatever ``EnvClass`` returns.

    Raises:
        OSError: If the config file cannot be opened.
        ValueError: If the document index entered is not an integer.
        IndexError: If the document index is out of range.
    """
    config_path = input('[Path to config> ')
    with open(config_path, 'r') as f:
        # safe_load_all never constructs arbitrary Python objects.
        docs = [d for d in yaml.safe_load_all(f)
                if d and 'name' in d and d['name'] not in ['SLURM']]
    for i, doc in enumerate(docs):
        print('['+str(i)+'] '+doc['name'])
    ds = int(input('[0]> ') or '0')
    doc = docs[ds]

    path = 'params.task.env_args'
    while True:
        # Always restart from the document root so that a newly entered
        # path is resolved in full, not continued from a half-walked state.
        cur = doc
        try:
            for key in path.split('.'):
                print(key)
                cur = cur[key]
        except (KeyError, IndexError, TypeError):
            print('Unable to find key "'+key+'"')
            path = input('[Path> ')
        else:
            break
    print(cur)
    return EnvClass(fps=30, **cur)
def chooseEnv():
    """Let the user pick one of the available environments and build it.

    The candidates from ``getAvaibleEnvs()`` are listed with their index
    and the user is prompted until a valid index is entered.  Choosing
    ``env.ColumbusConfigDefined`` delegates to ``loadConfigDefinedEnv``;
    any other choice is instantiated with
    ``fps=30, agent_draw_path=True, path_decay=1/256``.

    Returns:
        The constructed environment.

    Raises:
        RuntimeError: If no environment is available to choose from.
    """
    envs = list(getAvaibleEnvs())
    if not envs:
        # Without this guard the prompt below could never be satisfied.
        raise RuntimeError('No Columbus environments available')
    for i, Env in enumerate(envs):
        print('['+str(i)+'] '+Env.__name__)
    while True:
        inp = input('[#> ')
        try:
            i = int(inp)
        except ValueError:
            print('[!] You have to enter the number...')
            # Ask again instead of falling through with a stale index.
            continue
        if i < 0 or i >= len(envs):
            print(
                '[!] That is a number, but not one that makes sense in this context...')
            # Ask again; indexing here would raise or pick from the end.
            continue
        Env = envs[i]
        if Env in (env.ColumbusConfigDefined,):
            return loadConfigDefinedEnv(Env)
        return Env(fps=30, agent_draw_path=True, path_decay=1/256)
def value_func(obs):
    """Return ``obs[:, 0]``: index 0 along the second axis of ``obs``.

    ``obs`` must support two-axis indexing (e.g. a 2-D array or tensor).
    """
    # Alternative left in for experimentation:
    # return th.rand(obs.shape[0])-0.5
    first_column = obs[:, 0]
    return first_column
def human_input(obs, env):
    """Turn the current mouse position into a 2-D action in ``[-1, 1]``.

    The mouse position is taken relative to ``env.joystick_offset`` plus a
    20 pixel margin, divided by 60, clamped to ``[0, 1]`` and finally
    rescaled to ``[-1, 1]`` on each axis.

    Args:
        obs: Unused; present so all action sources share one signature.
        env: Object providing ``joystick_offset`` (indexable with 0 and 1).

    Returns:
        A ``(x, y)`` tuple with both components in ``[-1, 1]``.
    """
    mouse = pygame.mouse.get_pos()
    unit_x = min(max((mouse[0]-env.joystick_offset[0]-20)/60, 0), 1)
    unit_y = min(max((mouse[1]-env.joystick_offset[1]-20)/60, 0), 1)
    return unit_x*2-1, unit_y*2-1
class Colored_Noise():
    """Action source that replays a pre-generated colored-noise sequence.

    A buffer of shape ``(dim_a, samples)`` is drawn once from
    ``cn.powerlaw_psd_gaussian`` and then handed out one column per call.
    When the end of the buffer is reached, playback wraps around to the
    first column instead of failing.
    """

    def __init__(self, beta=1, dim_a=2, samples=2**18):
        """Generate the noise buffer.

        Args:
            beta: Exponent passed to ``cn.powerlaw_psd_gaussian``.
            dim_a: Number of action dimensions (rows of the buffer).
            samples: Number of time steps to pre-generate (columns).
        """
        self.index = 0
        self.samples = cn.powerlaw_psd_gaussian(beta, (dim_a, samples))

    def __call__(self, obs, env):
        """Return the next column of the buffer.

        ``obs`` and ``env`` are ignored; they exist so every action source
        can be called the same way.
        """
        sample = self.samples[:, self.index]
        # Wrap around so long sessions do not run off the end of the buffer.
        self.index = (self.index + 1) % self.samples.shape[1]
        return sample
class PCA_Noise():
    """Action source that samples from a ``pca.PCA_Distribution``.

    Every call samples one action conditioned on the trajectory of
    previously sampled actions.  After ``self.ssf`` calls the trajectory is
    reset to a single all-zero entry and the counter starts over.
    """

    def __init__(self, dim_a=2, kernel_func='SE_1.41_1', window=8, ssf=-1):
        """Set up the distribution and an empty trajectory.

        Args:
            dim_a: Number of action dimensions.
            kernel_func: Forwarded to ``pca.PCA_Distribution``.
            window: Forwarded to ``pca.PCA_Distribution``.
            ssf: Currently unused (see the note on ``self.ssf`` below).
        """
        self.dist = pca.PCA_Distribution(
            action_dim=dim_a, par_strength='CONT_DIAG',
            kernel_func=kernel_func, window=window)
        # One zero and one one per action dimension, so the tensors agree
        # with `action_dim` for any `dim_a` (not just the default of 2).
        # Presumably these are mean and std -- confirm against the pca API.
        self.dist.proba_distribution(
            th.Tensor([[0]*dim_a]), th.Tensor([[1]*dim_a]))
        self.traj = [[0]*dim_a]
        # NOTE(review): the `ssf` argument is accepted but ignored; the
        # reset period is fixed at 32.  Left as-is because the intended
        # meaning of the default -1 is not visible here -- confirm.
        self.ssf = 32
        self.index = 0

    def __call__(self, obs, env):
        """Sample the next action; ``obs`` and ``env`` are ignored."""
        if self.index == self.ssf:
            # Start a fresh trajectory with the same width as the first entry.
            self.index = 0
            self.traj = [[0]*len(self.traj[0])]
        # unsqueeze(0) adds a leading axis before sampling; squeeze(0)
        # removes it again from the result.
        traj = th.Tensor(self.traj).unsqueeze(0)
        sample = self.dist.sample(traj).squeeze(0)
        self.traj.append(sample)
        self.index += 1
        return sample
def choosePlayType():
    """Let the user pick the action source used to drive the environment.

    The options are listed with their index and the user is prompted
    until a valid index is entered.

    Returns:
        A callable taking ``(obs, env)`` and returning an action:
        ``human_input`` itself, or a freshly built ``PCA_Noise`` /
        ``Colored_Noise`` instance.
    """
    # Map each name to a zero-argument factory so that only the selected
    # noise generator is constructed; building all of them up front means
    # generating several large sample buffers that are then thrown away.
    options = {
        'human': lambda: human_input,
        'PCA': lambda: PCA_Noise(),
        'REX': lambda: Colored_Noise(beta=0),
        'PINK': lambda: Colored_Noise(beta=1),
        'BROWN': lambda: Colored_Noise(beta=2),
        'BETA.5': lambda: Colored_Noise(beta=.5),
    }
    names = list(options)
    for i, name in enumerate(names):
        print('['+str(i)+'] '+name)
    while True:
        inp = input('[#> ')
        try:
            i = int(inp)
        except ValueError:
            print('[!] You have to enter the number...')
            continue
        if i < 0 or i >= len(names):
            print(
                '[!] That is a number, but not one that makes sense in this context...')
        else:
            return options[names[i]]()
def playEnv(env, agent_func):
    """Play a single episode of ``env`` driven by ``agent_func``.

    Each frame renders the environment, asks ``agent_func(obs, env)`` for
    an action, converts it to a ``float32`` array and steps the
    environment, printing the reward and the info object.  The loop is
    throttled to ``env.fps`` frames per second; when a frame takes longer
    than its time budget a warning is printed instead of sleeping.

    Args:
        env: Object providing ``reset()``, ``render()``, ``step(action)``
            (returning ``obs, rew, done, info``) and ``fps``.
        agent_func: Callable ``(obs, env) -> action``.
    """
    obs = env.reset()
    finished = False
    while not finished:
        frame_start = time()
        # env.render(value_func=value_func)
        env.render()
        action = np.array(agent_func(obs, env), dtype=np.float32)
        obs, rew, finished, info = env.step(action)
        print('Reward: ' + str(rew))
        print('Score: ' + str(info))
        elapsed = time() - frame_start
        remaining = 1/env.fps - elapsed
        if remaining < 0:
            print("[!] Can't keep framerate!")
            continue
        sleep(remaining)
# Start the interactive session only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()