From 8c4e35bf41135c4788fb578648c7ae49b35f3e18 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 27 Sep 2021 00:22:35 +0200 Subject: [PATCH] Bug fix in training --- caliGraph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/caliGraph.py b/caliGraph.py index 718ac97..6c3f5cf 100755 --- a/caliGraph.py +++ b/caliGraph.py @@ -835,7 +835,7 @@ def train(gamma = 1, full=True): best_mse = mse stagLen = 0 - while gamma > 1.0e-06 and delta > 1.0e-06 or best_mse > 3: + while gamma > 1.0e-06 and delta > 1.0e-05 or best_mse > 3: last_mse = mse print({'mse': mse, 'gamma': gamma, 'delta': delta}) delta = sum(gradient[g]**2 for g in gradient) @@ -852,6 +852,7 @@ def train(gamma = 1, full=True): else: stagLen += 1 if stagLen == 3 or mse > 100: + stagLen = -2 for wt in weights: weights[wt] = random.random() print('Done.')