Lets test Perlin and PCA on Perlin
This commit is contained in:
parent
d04f245e9b
commit
c0913ba965
65
test.py
65
test.py
@ -9,6 +9,7 @@ from columbus import env
|
|||||||
from columbus.observables import Observable, CnnObservable
|
from columbus.observables import Observable, CnnObservable
|
||||||
|
|
||||||
import colorednoise as cn
|
import colorednoise as cn
|
||||||
|
from perlin_noise import PerlinNoise
|
||||||
|
|
||||||
from priorConditionedAnnealing import pca
|
from priorConditionedAnnealing import pca
|
||||||
|
|
||||||
@ -33,7 +34,7 @@ def getAvaibleEnvs():
|
|||||||
yield getattr(env, s)
|
yield getattr(env, s)
|
||||||
|
|
||||||
|
|
||||||
def loadConfigDefinedEnv(EnvClass):
|
def loadConfigDefinedEnv(EnvClass, alg_name):
|
||||||
p = input('[Path to config> ')
|
p = input('[Path to config> ')
|
||||||
with open(p, 'r') as f:
|
with open(p, 'r') as f:
|
||||||
docs = list([d for d in yaml.safe_load_all(
|
docs = list([d for d in yaml.safe_load_all(
|
||||||
@ -57,7 +58,7 @@ def loadConfigDefinedEnv(EnvClass):
|
|||||||
print('Unable to find key "'+key+'"')
|
print('Unable to find key "'+key+'"')
|
||||||
path = input('[Path> ')
|
path = input('[Path> ')
|
||||||
print(cur)
|
print(cur)
|
||||||
return EnvClass(fps=30, **cur)
|
return EnvClass(fps=30, title_appendix=' ['+alg_name+']', **cur)
|
||||||
|
|
||||||
|
|
||||||
def chooseEnv(alg_name):
|
def chooseEnv(alg_name):
|
||||||
@ -74,7 +75,7 @@ def chooseEnv(alg_name):
|
|||||||
print(
|
print(
|
||||||
'[!] That is a number, but not one that makes sense in this context...')
|
'[!] That is a number, but not one that makes sense in this context...')
|
||||||
if envs[i] in [env.ColumbusConfigDefined]:
|
if envs[i] in [env.ColumbusConfigDefined]:
|
||||||
return loadConfigDefinedEnv(envs[i])
|
return loadConfigDefinedEnv(envs[i], alg_name)
|
||||||
Env = envs[i]
|
Env = envs[i]
|
||||||
return Env(fps=30, agent_draw_path=True, path_decay=1/1024, title_appendix=' ['+alg_name+']', max_steps=30*10, clear_path_on_reset=False)
|
return Env(fps=30, agent_draw_path=True, path_decay=1/1024, title_appendix=' ['+alg_name+']', max_steps=30*10, clear_path_on_reset=False)
|
||||||
|
|
||||||
@ -111,12 +112,62 @@ class Colored_Noise():
|
|||||||
self.beta, (self.dim_a, self.samples), random_state=rand_seed())
|
self.beta, (self.dim_a, self.samples), random_state=rand_seed())
|
||||||
|
|
||||||
|
|
||||||
class PCA_Noise():
|
class Perlin_Noise():
|
||||||
def __init__(self, dim_a=2, kernel_func='SE_1.41_1', window=64, ssf=-1):
|
def __init__(self, scale=0.05, octaves=1, dim_a=2):
|
||||||
|
self.scale = scale
|
||||||
|
self.octaves = octaves
|
||||||
|
self.dim_a = dim_a
|
||||||
|
self.magic = 3.14159 # Axis offset
|
||||||
|
# We want to genrate samples, that approx ~N(0,1)
|
||||||
|
self.normal_factor = 0.0471
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def __call__(self, obs, env):
|
||||||
|
self.index += 1
|
||||||
|
return [self.noise([self.index*self.scale, self.magic*a]) / self.normal_factor
|
||||||
|
for a in range(self.dim_a)]
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.index = 0
|
||||||
|
self.noise = PerlinNoise(octaves=self.octaves, seed=rand_seed())
|
||||||
|
|
||||||
|
|
||||||
|
class Perlin_PCA_Noise():
|
||||||
|
def __init__(self, dim_a=2, kernel_func='SE_1.41_1', window=64, ssf=-1, f_sigma=1):
|
||||||
self.dim_a = dim_a
|
self.dim_a = dim_a
|
||||||
self.kernel_func = kernel_func
|
self.kernel_func = kernel_func
|
||||||
self.window = window
|
self.window = window
|
||||||
self.ssf = ssf
|
self.ssf = ssf
|
||||||
|
self.f_sigma = f_sigma
|
||||||
|
self.index = 0
|
||||||
|
self.perlin = Perlin_Noise()
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def __call__(self, obs, env):
|
||||||
|
if self.ssf != -1 and self.index % self.ssf == 0:
|
||||||
|
self.traj = [[0]*len(self.traj[0])]
|
||||||
|
traj = th.Tensor(self.traj).unsqueeze(0)
|
||||||
|
eps = th.Tensor(self.perlin(None, None)).unsqueeze(0)
|
||||||
|
sample = self.dist.sample(traj, self.f_sigma, epsilon=eps).squeeze(0)
|
||||||
|
self.traj.append(sample)
|
||||||
|
self.index += 1
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.dist = pca.PCA_Distribution(
|
||||||
|
action_dim=self.dim_a, par_strength='CONT_DIAG', kernel_func=self.kernel_func, window=self.window)
|
||||||
|
self.dist.proba_distribution(th.Tensor([[0]*2]), th.Tensor([[1]*2]))
|
||||||
|
self.traj = [[0]*self.dim_a]
|
||||||
|
self.perlin.reset()
|
||||||
|
|
||||||
|
|
||||||
|
class PCA_Noise():
|
||||||
|
def __init__(self, dim_a=2, kernel_func='SE_1.41_1', window=64, ssf=-1, f_sigma=1):
|
||||||
|
self.dim_a = dim_a
|
||||||
|
self.kernel_func = kernel_func
|
||||||
|
self.window = window
|
||||||
|
self.ssf = ssf
|
||||||
|
self.f_sigma = f_sigma
|
||||||
self.index = 0
|
self.index = 0
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -124,7 +175,7 @@ class PCA_Noise():
|
|||||||
if self.ssf != -1 and self.index % self.ssf == 0:
|
if self.ssf != -1 and self.index % self.ssf == 0:
|
||||||
self.traj = [[0]*len(self.traj[0])]
|
self.traj = [[0]*len(self.traj[0])]
|
||||||
traj = th.Tensor(self.traj).unsqueeze(0)
|
traj = th.Tensor(self.traj).unsqueeze(0)
|
||||||
sample = self.dist.sample(traj).squeeze(0)
|
sample = self.dist.sample(traj, self.f_sigma).squeeze(0)
|
||||||
self.traj.append(sample)
|
self.traj.append(sample)
|
||||||
self.index += 1
|
self.index += 1
|
||||||
return sample
|
return sample
|
||||||
@ -172,7 +223,7 @@ def rand_seed():
|
|||||||
|
|
||||||
def choosePlayType():
|
def choosePlayType():
|
||||||
options = {'human': human_input, 'PCA': PCA_Noise(),
|
options = {'human': human_input, 'PCA': PCA_Noise(),
|
||||||
'REX': Colored_Noise(beta=0), 'PINK': Colored_Noise(beta=1), 'BROWN': Colored_Noise(beta=2), 'BETA.5': Colored_Noise(beta=.5), 'PINK_PCA': Colored_PCA_Noise(beta=1)}
|
'REX': Colored_Noise(beta=0), 'PINK': Colored_Noise(beta=1), 'BROWN': Colored_Noise(beta=2), 'BETA.5': Colored_Noise(beta=.5), 'PINK_PCA': Colored_PCA_Noise(beta=1), 'Precise_PCA': PCA_Noise(f_sigma=0.33), 'Perlin': Perlin_Noise(scale=0.05, octaves=1), 'FastPerlin': Perlin_Noise(scale=0.2, octaves=1), 'SlowPerlin': Perlin_Noise(scale=0.0125, octaves=1), 'Perlin_3': Perlin_Noise(scale=0.05, octaves=3), 'Perlin_8': Perlin_Noise(scale=0.05, octaves=8), 'Perlin_PCA': Perlin_PCA_Noise()}
|
||||||
for i, name in enumerate(options):
|
for i, name in enumerate(options):
|
||||||
print('['+str(i)+'] '+name)
|
print('['+str(i)+'] '+name)
|
||||||
while True:
|
while True:
|
||||||
|
Loading…
Reference in New Issue
Block a user