refactored kernels into seperate file
This commit is contained in:
parent
255c8c379b
commit
e0471dbfb7
34
priorConditionedAnnealing/kernel.py
Normal file
34
priorConditionedAnnealing/kernel.py
Normal file
@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
|
||||
def rbf(l=1.0):
|
||||
return se(sig=1.414, l=l)
|
||||
|
||||
|
||||
def se(sig=1.414, l=1.0):
|
||||
def func(xa, xb):
|
||||
sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l
|
||||
return (sig**2)*np.exp(sq)
|
||||
return func
|
||||
|
||||
|
||||
def brown(f=0.02, eps=0.00000001):
|
||||
def func(xa, xb):
|
||||
m = []
|
||||
for i in range(xa.shape[0]):
|
||||
l = []
|
||||
for j in range(xb.shape[0]):
|
||||
l.append(min(xa[i][0], xb[j][0]))
|
||||
m.append(l)
|
||||
|
||||
return (np.array(m) + eps)*f
|
||||
return func
|
||||
|
||||
|
||||
def pink(sig=1.414):
|
||||
def func(xa, xb):
|
||||
d = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean')
|
||||
m = (d == 1)*1 + (d == 0)*sig
|
||||
return m
|
||||
return func
|
@ -8,7 +8,7 @@ from stable_baselines3.common.distributions import sum_independent_dims
|
||||
from torch.distributions import Normal
|
||||
import torch.nn.functional as F
|
||||
|
||||
from priorConditionedAnnealing import noise
|
||||
from priorConditionedAnnealing import noise, kernel
|
||||
|
||||
|
||||
class Par_Strength(Enum):
|
||||
@ -34,24 +34,15 @@ class EnforcePositiveType(Enum):
|
||||
return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
|
||||
|
||||
|
||||
def rbf(l=1.0):
|
||||
return se(sig=1.0, l=l)
|
||||
|
||||
|
||||
def se(sig=1.0, l=1.0):
|
||||
def func(xa, xb):
|
||||
sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l
|
||||
return (sig**2)*np.exp(sq)
|
||||
return func
|
||||
|
||||
|
||||
class Avaible_Kernel_Funcs(Enum):
|
||||
RBF = 0
|
||||
SE = 1
|
||||
BROWN = 2
|
||||
PINK = 3
|
||||
|
||||
def get_func(self):
|
||||
# stil aaaaaaaa
|
||||
return [rbf, se][self.value]
|
||||
return [kernel.rbf, kernel.se, kernel.brown, kernel.pink][self.value]
|
||||
|
||||
|
||||
def cast_to_enum(inp, Class):
|
||||
@ -75,7 +66,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
self,
|
||||
action_dim: int,
|
||||
par_strength: Par_Strength = Par_Strength.CONT_DIAG,
|
||||
kernel_func=rbf(),
|
||||
kernel_func=kernel.rbf(),
|
||||
init_std: float = 1,
|
||||
cond_noise: float = 0,
|
||||
window: int = 64,
|
||||
|
Loading…
Reference in New Issue
Block a user