Tweaked training
This commit is contained in:
parent
cf85995386
commit
fb3a5592df
15
caliGraph.py
15
caliGraph.py
@ -682,7 +682,10 @@ def evaluateFitness():
|
||||
rating = G.nodes[m]['rating']
|
||||
G.nodes[m]['rating'] = None
|
||||
mu, std = genScores(G, books)
|
||||
errSq.append((rating - G.nodes[m]['score'])**2)
|
||||
if G.nodes[m]['score'] > rating: # over estimated
|
||||
errSq.append(((rating - G.nodes[m]['score'])**2)*1.5)
|
||||
else:
|
||||
errSq.append((rating - G.nodes[m]['score'])**2)
|
||||
G.nodes[m]['rating'] = rating
|
||||
return sum(errSq) / len(errSq)
|
||||
|
||||
@ -697,11 +700,17 @@ def train(gamma = 0.1):
|
||||
while True:
|
||||
print({'mse': best_mse, 'w': weights, 'gamma': gamma})
|
||||
weights = copy.copy(bestWeights)
|
||||
weights[attr] += delta
|
||||
if gamma < 0.01 and random.random() < 0.5:
|
||||
gamma = 0.01
|
||||
weights[attr] = -1+random.random()*2
|
||||
else:
|
||||
weights[attr] += delta
|
||||
if attr not in ['sigma, mu']:
|
||||
weights[attr] = min(max(0, weight[attr]), 1.5)
|
||||
mse = evaluateFitness()
|
||||
if mse < best_mse: # got better
|
||||
saveWeights(weights)
|
||||
gamma *= 1.1
|
||||
gamma *= 1.75
|
||||
bestWeights = copy.copy(weights)
|
||||
best_mse = mse
|
||||
delta *= 2
|
||||
|
Loading…
Reference in New Issue
Block a user