diff --git a/priorConditionedAnnealing/kernel.py b/priorConditionedAnnealing/kernel.py new file mode 100644 index 0000000..c5bdecb --- /dev/null +++ b/priorConditionedAnnealing/kernel.py @@ -0,0 +1,34 @@ +import numpy as np +import scipy + + +def rbf(l=1.0): + return se(sig=1.414, l=l) + + +def se(sig=1.414, l=1.0): + def func(xa, xb): + sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l + return (sig**2)*np.exp(sq) + return func + + +def brown(f=0.02, eps=0.00000001): + def func(xa, xb): + m = [] + for i in range(xa.shape[0]): + l = [] + for j in range(xb.shape[0]): + l.append(min(xa[i][0], xb[j][0])) + m.append(l) + + return (np.array(m) + eps)*f + return func + + +def pink(sig=1.414): + def func(xa, xb): + d = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') + m = (d == 1)*1 + (d == 0)*sig + return m + return func diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index 02623de..e3d3348 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -8,7 +8,7 @@ from stable_baselines3.common.distributions import sum_independent_dims from torch.distributions import Normal import torch.nn.functional as F -from priorConditionedAnnealing import noise +from priorConditionedAnnealing import noise, kernel class Par_Strength(Enum): @@ -34,24 +34,15 @@ class EnforcePositiveType(Enum): return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x) -def rbf(l=1.0): - return se(sig=1.0, l=l) - - -def se(sig=1.0, l=1.0): - def func(xa, xb): - sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l - return (sig**2)*np.exp(sq) - return func - - class Avaible_Kernel_Funcs(Enum): RBF = 0 SE = 1 + BROWN = 2 + PINK = 3 def get_func(self): # stil aaaaaaaa - return [rbf, se][self.value] + return [kernel.rbf, kernel.se, kernel.brown, kernel.pink][self.value] def cast_to_enum(inp, Class): @@ -75,7 +66,7 @@ class PCA_Distribution(SB3_Distribution): self, action_dim: int, par_strength: Par_Strength = Par_Strength.CONT_DIAG, - kernel_func=rbf(), + kernel_func=kernel.rbf(), init_std: float = 1, cond_noise: float = 0, window: int = 64,