PriorConditionedAnnealing/test.py

161 lines
4.0 KiB
Python
Raw Normal View History

import torch as th
from time import sleep, time
import numpy as np
import pygame
import yaml
from columbus import env
from columbus.observables import Observable, CnnObservable
import colorednoise as cn
from pca import *
def main():
    """Interactive entry point: pick an agent and an environment, then play.

    Asks for the play type and the environment once, then replays episodes
    forever, waiting for the user to press enter between episodes.
    """
    agent_func = choosePlayType()
    env = chooseEnv()
    try:
        while True:
            playEnv(env, agent_func)
            input('<again?>')
    finally:
        # The loop above never terminates normally; it is only left through an
        # exception (e.g. Ctrl-C or EOF at the prompt). Closing in `finally`
        # makes sure the environment is released in that case; the exception
        # itself still propagates to the caller.
        env.close()
def getAvaibleEnvs():
    """Yield the attributes of the `env` module that look like environments.

    An attribute qualifies when its name starts with 'Columbus' and is not
    exactly 'ColumbusEnv'. Discovery is purely name-based (the original author
    called this "kinda hacky").
    """
    for attr_name in dir(env):
        if not attr_name.startswith('Columbus'):
            continue
        if attr_name == 'ColumbusEnv':
            # Excluded by name; presumably the shared base class - confirm.
            continue
        yield getattr(env, attr_name)
def loadConfigDefinedEnv(EnvClass):
    """Interactively build ``EnvClass`` from a multi-document YAML config.

    Prompts for the config file path, lists every YAML document in it that
    has a 'name' key (documents named 'SLURM' are skipped) and lets the user
    pick one by index (empty input selects 0). The dotted key path
    'params.task.env_args' is then resolved inside the chosen document and
    the value found there is passed to ``EnvClass`` as keyword arguments,
    together with ``fps=30``.

    If a key along the path cannot be resolved, the user is asked for a
    different dotted path and the lookup restarts from the root of the chosen
    document.

    Args:
        EnvClass: callable accepting ``fps`` plus the keyword arguments
            stored in the config.

    Returns:
        Whatever ``EnvClass(fps=30, **env_args)`` returns.
    """
    config_path = input('[Path to config> ')
    with open(config_path, 'r') as f:
        docs = [d for d in yaml.safe_load_all(f)
                if d and 'name' in d and d['name'] not in ['SLURM']]
    for i, doc in enumerate(docs):
        print('['+str(i)+'] '+doc['name'])
    ds = int(input('[0]> ') or '0')
    doc = docs[ds]

    path = 'params.task.env_args'
    while True:
        # Every attempt starts at the document root so that a path entered
        # after a failed lookup is resolved in full, not from a stale node.
        cur = doc
        try:
            for key in path.split('.'):
                print(key)
                cur = cur[key]
        except (KeyError, IndexError, TypeError):
            # KeyError: missing key; TypeError/IndexError: the current node
            # is not a mapping (e.g. None, a scalar or a list).
            print('Unable to find key "'+key+'"')
            path = input('[Path> ')
        else:
            break
    print(cur)
    return EnvClass(fps=30, **cur)
def chooseEnv():
    """Let the user pick one of the available environments and instantiate it.

    Prints an indexed menu of everything yielded by ``getAvaibleEnvs()`` and
    keeps prompting until a valid index is entered. The config-defined
    environment class is built through ``loadConfigDefinedEnv``; every other
    class is instantiated directly with ``fps=30``.

    Returns:
        The instantiated environment.
    """
    envs = list(getAvaibleEnvs())
    for i, Env in enumerate(envs):
        print('['+str(i)+'] '+Env.__name__)
    while True:
        inp = input('[#> ')
        try:
            i = int(inp)
        except ValueError:
            print('[!] You have to enter the number...')
            # Ask again instead of falling through with a stale index.
            continue
        if i < 0 or i >= len(envs):
            print(
                '[!] That is a number, but not one that makes sense in this context...')
            continue
        Env = envs[i]
        if Env in [env.ColumbusConfigDefined]:
            return loadConfigDefinedEnv(Env)
        return Env(fps=30)
def value_func(obs):
    """Return ``obs[:, 0]``: index 0 along the second axis for every row.

    ``obs`` must support two-dimensional slicing (e.g. an array or tensor).
    """
    # A random alternative used previously: th.rand(obs.shape[0])-0.5
    first_column = obs[:, 0]
    return first_column
def human_input(obs):
    """Translate the current mouse position into a 2D action.

    ``obs`` is accepted for interface compatibility but not used. Each mouse
    coordinate is shifted by the matching ``env.joystick_offset`` entry plus
    20, divided by 60, clamped to [0, 1] and finally rescaled to [-1, 1].

    Returns:
        A tuple ``(x, y)`` with both components in [-1, 1].
    """
    # NOTE(review): there is no local `env` here, so this resolves to the
    # module-level `env` imported from `columbus`, not to the environment
    # instance handed to playEnv - confirm that it provides `joystick_offset`.
    mouse_x, mouse_y = pygame.mouse.get_pos()

    def to_unit(pixel, offset):
        # Clamp the offset-corrected pixel position into the unit interval.
        return min(max((pixel-offset-20)/60, 0), 1)

    unit_x = to_unit(mouse_x, env.joystick_offset[0])
    unit_y = to_unit(mouse_y, env.joystick_offset[1])
    return unit_x*2-1, unit_y*2-1
def colored_noise(beta=1, dim_a=2, samples=2**18):
    """Build an agent function that emits pre-generated colored noise.

    One independent sequence of length ``samples`` is drawn per action
    dimension with ``cn.powerlaw_psd_gaussian(beta, samples)``. Each call of
    the returned function hands out the next time step of every sequence.
    Once all ``samples`` steps are used up, fresh sequences are drawn.

    Args:
        beta: exponent forwarded to ``cn.powerlaw_psd_gaussian`` (1 is
            presumably pink noise, cf. the 'PINK' play type - confirm).
        dim_a: number of action dimensions.
        samples: number of time steps generated per sequence.

    Returns:
        A function ``noise_generator(obs)`` that ignores ``obs`` and returns
        a list with ``dim_a`` noise values.
    """
    def draw_sequences():
        # One sequence per action dimension, each `samples` steps long.
        return [cn.powerlaw_psd_gaussian(beta, samples) for _ in range(dim_a)]

    sequences = draw_sequences()
    index = 0

    def noise_generator(obs):
        # Both names are rebound below, so they must be declared nonlocal.
        nonlocal sequences, index
        if index >= samples:
            # Buffer exhausted: start over with newly drawn sequences.
            sequences = draw_sequences()
            index = 0
        sample = [seq[index] for seq in sequences]
        index += 1
        return sample
    return noise_generator
def pca_noise(lengthscale=1, dim_a=2, kernel_func='RBF', window=16):
    """Build an agent function that samples from a ``PCA_Distribution``.

    The distribution is created with ``action_dim=dim_a``,
    ``par_strength='SCALAR'`` and the given ``kernel_func`` / ``window``.
    Every call of the returned function samples from it, conditioned on all
    samples returned so far, and appends the new sample to that history.

    Note: ``lengthscale`` is accepted but currently not used.

    Returns:
        A function ``noise_generator(obs)`` that ignores ``obs`` and returns
        the newly drawn sample.
    """
    # PCA_Distribution is not defined in this file; presumably it comes from
    # the star import of `pca` - confirm.
    distribution = PCA_Distribution(
        action_dim=dim_a,
        par_strength='SCALAR',
        kernel_func=kernel_func,
        window=window,
    )
    history = []

    def noise_generator(obs):
        drawn = distribution.sample(th.Tensor(history))
        history.append(drawn)
        return drawn

    return noise_generator
def choosePlayType():
    """Let the user pick how actions are produced and return that agent.

    Prints an indexed menu of the known play types and keeps prompting until
    a valid index is entered.

    Returns:
        The value registered for the chosen play type: a function mapping an
        observation to an action for 'human', ``None`` for the others.
    """
    # NOTE(review): only 'human' is wired to a callable; 'REX', 'PCA' and
    # 'PINK' map to None, which playEnv cannot call.
    options = {'human': human_input, 'REX': None, 'PCA': None, 'PINK': None}
    names = list(options)
    for i, name in enumerate(names):
        print('['+str(i)+'] '+name)
    while True:
        inp = input('[#> ')
        try:
            i = int(inp)
        except ValueError:
            print('[!] You have to enter the number...')
            # Ask again instead of falling through with a stale index.
            continue
        if i < 0 or i >= len(names):
            print(
                '[!] That is a number, but not one that makes sense in this context...')
            continue
        # Look up the entry the user selected, not the leftover loop variable.
        return options[names[i]]
def playEnv(env, agent_func):
    """Play one episode of ``env`` with actions supplied by ``agent_func``.

    Each frame renders the environment, asks ``agent_func`` for an action
    based on the latest observation, converts it to a float32 numpy array and
    steps the environment. The reward and the info value of every step are
    printed. Frames are paced to ``env.fps``; a warning is printed whenever a
    frame took longer than its time budget.

    Args:
        env: object providing ``reset()``, ``render()``, ``step(action)``
            (returning ``obs, reward, done, info``) and an ``fps`` attribute.
        agent_func: callable mapping an observation to an action.
    """
    obs = env.reset()
    finished = False
    while not finished:
        frame_start = time()
        env.render()
        action = np.array(agent_func(obs), dtype=np.float32)
        obs, rew, finished, info = env.step(action)
        print('Reward: '+str(rew))
        print('Score: '+str(info))
        # Time left in this frame's budget after rendering and stepping.
        remaining = 1/env.fps - (time() - frame_start)
        if remaining < 0:
            print("[!] Can't keep framerate!")
        else:
            sleep(remaining)
# Start the interactive player only when this file is run as a script,
# not when it is imported.
if __name__ == '__main__':
    main()