From 7fd1f4fa3f5ee8fc9ff8e82b39277f3a3d6c083a Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 5 Oct 2021 18:08:32 +0200 Subject: [PATCH] Better shit-detection and mitigation when training --- caliGraph.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/caliGraph.py b/caliGraph.py index d602040..4cf0a5d 100755 --- a/caliGraph.py +++ b/caliGraph.py @@ -823,11 +823,12 @@ def evaluateFitness(books, debugPrint=False): fit = sum(errSq)/len(errSq) + 0.005*regressionLoss + 0.2*boundsLoss/len(ratedBooks) - 1.0*sum(linSepLoss)/len(linSepLoss) return fit, gradient -def train(gamma = 1, full=True): +def train(initGamma = 1, full=True): global weights if full: for wt in weights: weights[wt] = random.random() + gamma = initGamma books = loadBooksFromDB() bestWeights = copy.copy(weights) mse, gradient = evaluateFitness(books) @@ -849,12 +850,22 @@ def train(gamma = 1, full=True): if mse < best_mse: saveWeights(weights) bestWeights = copy.copy(weights) - else: + best_mse = mse + if mse > last_mse: stagLen += 1 - if stagLen == 3 or mse > 100: - stagLen = -2 - for wt in weights: - weights[wt] = random.random() + else: + stagLen = 0 + if stagLen == 4 or mse > 50: + print("#") + stagLen = 0 + gamma = initGamma + if random.random() < 0.50: + for wt in weights: + weights[wt] = random.random() + else: + weights = copy.copy(bestWeights) + for wt in weights: + weights[wt] *= 0.975+0.05*random.random() print('Done.') def saveWeights(weights):