CaliGraph/py/gp.py
2022-02-24 22:29:52 +01:00

76 lines
2.4 KiB
Python

import numpy as np
from node2vec import Node2Vec
from sklearn.gaussian_process.kernels import Kernel, Hyperparameter
from sklearn.gaussian_process.kernels import GenericKernelMixin
from sklearn.gaussian_process import GaussianProcessRegressor
#from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.base import clone
class BookKernel(GenericKernelMixin, Kernel):
    """Gaussian-process kernel over graph nodes, backed by node embeddings.

    The kernel is evaluated directly on node identifiers rather than on
    feature vectors (hence ``GenericKernelMixin``). For two nodes ``a`` and
    ``b`` it returns ``0.99 * wv.similarity(a, b) ** 2 + 0.01``. Assuming
    ``wv.similarity`` is a cosine similarity in [-1, 1] (as for gensim
    ``KeyedVectors``), every kernel value lies in [0.01, 1].

    The kernel has no hyperparameters, so ``theta`` is empty.
    """

    def __init__(self, wv):
        # Keyed vectors exposing ``similarity(key1, key2)``; genGprScores
        # passes the ``.wv`` of a fitted node2vec model.
        self.wv = wv

    def _f(self, s1, s2):
        """Return the kernel value between the two nodes ``s1`` and ``s2``."""
        # node2vec stores node names as strings, and gensim 4 falls back to
        # treating a bare int key as a vocabulary *position*; sklearn may also
        # hand us numpy scalars. Normalizing to ``str`` makes every lookup go
        # by node name.
        sim = self.wv.similarity(str(s1), str(s2))
        # Squaring keeps the kernel positive semi-definite (element-wise
        # product of PSD kernels); the constant term keeps it PSD while
        # bounding every value away from zero.
        return sim ** 2 * 0.99 + 0.01

    def __call__(self, X, Y=None, eval_gradient=False):
        """Return the kernel matrix ``K[i, j] = k(X[i], Y[j])``.

        Parameters
        ----------
        X, Y : sequences of node identifiers. ``Y`` defaults to ``X``.
        eval_gradient : bool
            If True, also return the gradient with respect to the
            hyperparameters. Only supported when ``Y`` is None.

        Returns
        -------
        K : ndarray of shape (len(X), len(Y)), dtype float64
        K_gradient : ndarray of shape (len(X), len(X), 0)
            Only returned when ``eval_gradient`` is True. It is empty because
            the kernel has no hyperparameters.

        Raises
        ------
        ValueError
            If ``eval_gradient`` is True and ``Y`` is given.
        """
        if Y is None:
            Y = X
        elif eval_gradient:
            raise ValueError("Gradient can only be evaluated when Y is None.")
        # float64 so the tiny ``alpha`` jitter GaussianProcessRegressor adds to
        # the diagonal is not rounded away if similarities come back as
        # float32; the reshape gives a proper 2-D shape for empty inputs too.
        K = np.array(
            [[self._f(x, y) for y in Y] for x in X], dtype=float
        ).reshape(len(X), len(Y))
        if eval_gradient:
            return K, np.empty((len(X), len(X), 0))
        return K

    def diag(self, X):
        """Return the diagonal of ``self(X)`` as a 1-D array of shape (len(X),).

        Only the ``len(X)`` self-similarities are evaluated, not the full
        matrix. ``GaussianProcessRegressor.predict(return_std=True)`` relies on
        this being 1-D.
        """
        return np.array([self._f(x, x) for x in X], dtype=float)

    def is_stationary(self):
        """Report that the kernel is not stationary (it is keyed on node identity)."""
        return False

    def clone_with_theta(self, theta):
        """Return a clone of this kernel with its hyperparameters set to ``theta``."""
        cloned = clone(self)
        cloned.theta = theta
        return cloned


def genGprScores(G, globMu, globStd, scoreName='gpr_score', stdName='gpr_std'):
    """Fill in missing node ratings of ``G`` with Gaussian-process predictions.

    Nodes whose ``'rating'`` attribute is present and not None form the
    training set. Every other node gets the GP posterior mean written to
    ``node[scoreName]`` and the posterior standard deviation written to
    ``node[stdName]``, both as Python floats. ``G`` is modified in place and
    nothing is returned. The kernel is a ``BookKernel`` over node2vec
    embeddings of ``G``.

    Parameters
    ----------
    G : graph exposing ``G.nodes`` and per-node attribute dicts
        ``G.nodes[n]`` (presumably networkx — confirm against callers).
    globMu, globStd : currently unused by this function; kept for interface
        compatibility.
        NOTE(review): the GP uses a zero prior mean on raw ratings, so
        predictions shrink toward 0 away from rated nodes. Were these meant
        for centering/scaling?
    scoreName, stdName : node-attribute keys to write the mean and std to.

    Raises
    ------
    ValueError
        If no node in ``G`` has a rating to train on.
    """
    # Partition the nodes once, before the expensive embedding step.
    rated, ratings, unrated = [], [], []
    for n in G.nodes:
        rating = G.nodes[n].get('rating')
        if rating is None:
            unrated.append(n)
        else:
            rated.append(n)
            ratings.append(rating)
    if not unrated:
        return  # nothing to infer
    if not rated:
        raise ValueError("genGprScores: no node in G has a 'rating' to train on")

    # Raw strings: '\]' is not a valid escape; the printed text is unchanged.
    print(r'[\] Constructing Vectorizer')
    node2vec = Node2Vec(G, dimensions=32, walk_length=16, num_walks=128, workers=8)
    print(r'[\] Fitting Embeddings for Kernel')
    model = node2vec.fit(window=8, min_count=1, batch_words=4)
    print(r'[\] Constructing Kernel')
    kernel = BookKernel(model.wv)

    print(r'[\] Fitting GP')
    gpr = GaussianProcessRegressor(kernel=kernel, random_state=3141, alpha=1e-8).fit(rated, ratings)
    print(r'[\] Inferencing GP')
    means, stds = gpr.predict(unrated, return_std=True)
    # With a correct 1-D ``diag``, ``stds`` holds one std per unrated node.
    for n, mean, std in zip(unrated, means, stds):
        node = G.nodes[n]
        node[scoreName], node[stdName] = float(mean), float(std)