Implement 'plausability tests' of PCA vs REX / PINK in Columbus envs
This commit is contained in:
parent
21de8f418b
commit
44b34fe12b
61
test.py
61
test.py
@ -9,7 +9,7 @@ from columbus.observables import Observable, CnnObservable
|
|||||||
|
|
||||||
import colorednoise as cn
|
import colorednoise as cn
|
||||||
|
|
||||||
from pca import *
|
from priorConditionedAnnealing import pca
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -72,7 +72,7 @@ def chooseEnv():
|
|||||||
if envs[i] in [env.ColumbusConfigDefined]:
|
if envs[i] in [env.ColumbusConfigDefined]:
|
||||||
return loadConfigDefinedEnv(envs[i])
|
return loadConfigDefinedEnv(envs[i])
|
||||||
Env = envs[i]
|
Env = envs[i]
|
||||||
return Env(fps=30)
|
return Env(fps=30, agent_draw_path=True, path_decay=1/256)
|
||||||
|
|
||||||
|
|
||||||
def value_func(obs):
|
def value_func(obs):
|
||||||
@ -80,7 +80,7 @@ def value_func(obs):
|
|||||||
# return th.rand(obs.shape[0])-0.5
|
# return th.rand(obs.shape[0])-0.5
|
||||||
|
|
||||||
|
|
||||||
def human_input(obs):
|
def human_input(obs, env):
|
||||||
pos = (0.5, 0.5)
|
pos = (0.5, 0.5)
|
||||||
pos = pygame.mouse.get_pos()
|
pos = pygame.mouse.get_pos()
|
||||||
pos = (min(max((pos[0]-env.joystick_offset[0]-20)/60, 0), 1),
|
pos = (min(max((pos[0]-env.joystick_offset[0]-20)/60, 0), 1),
|
||||||
@ -89,39 +89,40 @@ def human_input(obs):
|
|||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
|
||||||
def colored_noise(beta=1, dim_a=2, samples=2**18):
|
class Colored_Noise():
|
||||||
index = [0]*dim_a
|
def __init__(self, beta=1, dim_a=2, samples=2**18):
|
||||||
samples = []
|
self.index = 0
|
||||||
for d in range(dim_a):
|
self.samples = cn.powerlaw_psd_gaussian(beta, (dim_a, samples))
|
||||||
samples.append(cn.powerlaw_psd_gaussian(beta, samples))
|
|
||||||
|
|
||||||
def noise_generator(obs):
|
def __call__(self, obs, env):
|
||||||
sample = []
|
sample = self.samples[:, self.index]
|
||||||
for d in range(dim_a):
|
self.index += 1
|
||||||
sample.append(samples[d][index])
|
|
||||||
index += 1
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return noise_generator
|
|
||||||
|
|
||||||
|
class PCA_Noise():
|
||||||
|
def __init__(self, dim_a=2, kernel_func='SE_1.41_1', window=8, ssf=-1):
|
||||||
|
self.dist = pca.PCA_Distribution(
|
||||||
|
action_dim=dim_a, par_strength='CONT_DIAG', kernel_func=kernel_func, window=window)
|
||||||
|
self.dist.proba_distribution(th.Tensor([[0]*2]), th.Tensor([[1]*2]))
|
||||||
|
self.traj = [[0]*dim_a]
|
||||||
|
self.ssf = 32
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
def pca_noise(lengthscale=1, dim_a=2, kernel_func='RBF', window=16):
|
def __call__(self, obs, env):
|
||||||
|
if self.index == self.ssf:
|
||||||
dist = PCA_Distribution(
|
self.index = 0
|
||||||
action_dim=dim_a, par_strength='SCALAR', kernel_func=kernel_func, window=window)
|
self.traj = [[0]*len(self.traj[0])]
|
||||||
|
traj = th.Tensor(self.traj).unsqueeze(0)
|
||||||
traj = []
|
sample = self.dist.sample(traj).squeeze(0)
|
||||||
|
self.traj.append(sample)
|
||||||
def noise_generator(obs):
|
self.index += 1
|
||||||
sample = dist.sample(th.Tensor(traj))
|
|
||||||
traj.append(sample)
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
return noise_generator
|
|
||||||
|
|
||||||
|
|
||||||
def choosePlayType():
|
def choosePlayType():
|
||||||
options = {'human': human_input, 'REX': None, 'PCA': None, 'PINK': None}
|
options = {'human': human_input, 'PCA': PCA_Noise(),
|
||||||
|
'REX': Colored_Noise(beta=0), 'PINK': Colored_Noise(beta=1), 'BROWN': Colored_Noise(beta=2), 'BETA.5': Colored_Noise(beta=.5)}
|
||||||
for i, name in enumerate(options):
|
for i, name in enumerate(options):
|
||||||
print('['+str(i)+'] '+name)
|
print('['+str(i)+'] '+name)
|
||||||
while True:
|
while True:
|
||||||
@ -130,10 +131,12 @@ def choosePlayType():
|
|||||||
i = int(inp)
|
i = int(inp)
|
||||||
except:
|
except:
|
||||||
print('[!] You have to enter the number...')
|
print('[!] You have to enter the number...')
|
||||||
|
continue
|
||||||
if i < 0 or i >= len(options):
|
if i < 0 or i >= len(options):
|
||||||
print(
|
print(
|
||||||
'[!] That is a number, but not one that makes sense in this context...')
|
'[!] That is a number, but not one that makes sense in this context...')
|
||||||
return options[name]
|
else:
|
||||||
|
return options[list(options.keys())[i]]
|
||||||
|
|
||||||
|
|
||||||
def playEnv(env, agent_func):
|
def playEnv(env, agent_func):
|
||||||
@ -143,7 +146,7 @@ def playEnv(env, agent_func):
|
|||||||
t1 = time()
|
t1 = time()
|
||||||
# env.render(value_func=value_func)
|
# env.render(value_func=value_func)
|
||||||
env.render()
|
env.render()
|
||||||
inp = agent_func(obs)
|
inp = agent_func(obs, env)
|
||||||
obs, rew, done, info = env.step(np.array(inp, dtype=np.float32))
|
obs, rew, done, info = env.step(np.array(inp, dtype=np.float32))
|
||||||
print('Reward: '+str(rew))
|
print('Reward: '+str(rew))
|
||||||
print('Score: '+str(info))
|
print('Score: '+str(info))
|
||||||
|
Loading…
Reference in New Issue
Block a user