refactored kernels into separate file
This commit is contained in:
parent 255c8c379b
commit e0471dbfb7

34  priorConditionedAnnealing/kernel.py  Normal file
@@ -0,0 +1,34 @@
+import numpy as np
+import scipy
+
+
+def rbf(l=1.0):
+    return se(sig=1.414, l=l)
+
+
+def se(sig=1.414, l=1.0):
+    def func(xa, xb):
+        sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l
+        return (sig**2)*np.exp(sq)
+    return func
+
+
+def brown(f=0.02, eps=0.00000001):
+    def func(xa, xb):
+        m = []
+        for i in range(xa.shape[0]):
+            l = []
+            for j in range(xb.shape[0]):
+                l.append(min(xa[i][0], xb[j][0]))
+            m.append(l)
+
+        return (np.array(m) + eps)*f
+    return func
+
+
+def pink(sig=1.414):
+    def func(xa, xb):
+        d = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean')
+        m = (d == 1)*1 + (d == 0)*sig
+        return m
+    return func
@@ -8,7 +8,7 @@ from stable_baselines3.common.distributions import sum_independent_dims
 from torch.distributions import Normal
 import torch.nn.functional as F
 
-from priorConditionedAnnealing import noise
+from priorConditionedAnnealing import noise, kernel
 
 
 class Par_Strength(Enum):
@@ -34,24 +34,15 @@ class EnforcePositiveType(Enum):
         return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
 
 
-def rbf(l=1.0):
-    return se(sig=1.0, l=l)
-
-
-def se(sig=1.0, l=1.0):
-    def func(xa, xb):
-        sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l
-        return (sig**2)*np.exp(sq)
-    return func
-
-
 class Avaible_Kernel_Funcs(Enum):
     RBF = 0
     SE = 1
+    BROWN = 2
+    PINK = 3
 
     def get_func(self):
         # stil aaaaaaaa
-        return [rbf, se][self.value]
+        return [kernel.rbf, kernel.se, kernel.brown, kernel.pink][self.value]
 
 
 def cast_to_enum(inp, Class):
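With this change the enum no longer carries kernel code of its own and instead indexes into kernel.py. A small sketch of the resulting dispatch, using only names visible in the diff; note that get_func returns the factory, which still has to be called to obtain the kernel callable:

    factory = Avaible_Kernel_Funcs.BROWN.get_func()   # -> kernel.brown
    kernel_func = factory()                           # callable(xa, xb) with the default f and eps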
@@ -75,7 +66,7 @@ class PCA_Distribution(SB3_Distribution):
         self,
         action_dim: int,
         par_strength: Par_Strength = Par_Strength.CONT_DIAG,
-        kernel_func=rbf(),
+        kernel_func=kernel.rbf(),
         init_std: float = 1,
         cond_noise: float = 0,
         window: int = 64,
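Downstream, the default kernel_func simply becomes kernel.rbf(), and any of the new factories can be passed instead. A hypothetical construction, assuming the module path and any constructor arguments after window (not shown in this hunk) keep their defaults:

    from priorConditionedAnnealing import kernel
    from priorConditionedAnnealing.pca import PCA_Distribution  # import path assumed, not shown in this diff

    dist = PCA_Distribution(
        action_dim=6,                 # illustrative value
        kernel_func=kernel.brown(),   # swap in the Brownian-motion kernel
    )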