diff --git a/caliGraph.py b/caliGraph.py index 6e08c59..bdc2df5 100755 --- a/caliGraph.py +++ b/caliGraph.py @@ -682,7 +682,10 @@ def evaluateFitness(): rating = G.nodes[m]['rating'] G.nodes[m]['rating'] = None mu, std = genScores(G, books) - errSq.append((rating - G.nodes[m]['score'])**2) + if G.nodes[m]['score'] > rating: # over estimated + errSq.append(((rating - G.nodes[m]['score'])**2)*1.5) + else: + errSq.append((rating - G.nodes[m]['score'])**2) G.nodes[m]['rating'] = rating return sum(errSq) / len(errSq) @@ -697,11 +700,17 @@ def train(gamma = 0.1): while True: print({'mse': best_mse, 'w': weights, 'gamma': gamma}) weights = copy.copy(bestWeights) - weights[attr] += delta + if gamma < 0.01 and random.random() < 0.5: + gamma = 0.01 + weights[attr] = -1+random.random()*2 + else: + weights[attr] += delta + if attr not in ['sigma, mu']: + weights[attr] = min(max(0, weight[attr]), 1.5) mse = evaluateFitness() if mse < best_mse: # got better saveWeights(weights) - gamma *= 1.1 + gamma *= 1.75 bestWeights = copy.copy(weights) best_mse = mse delta *= 2