training now terminates earlier when stagnating (can be disabled via

flag)
This commit is contained in:
Dominik Moritz Roth 2021-09-26 15:52:54 +02:00
parent 787404c134
commit f3240147d5

View File

@ -638,7 +638,7 @@ def recommendNBooksTagBased(G, mu, std, n, removeTopListsB=True):
def recommendNBooks(G, mu, std, n, removeTopListsB=True, removeUselessRecommenders=True):
removeRestOfSeries(G)
removeBad(G, mu-std*2-1)
removeBad(G, mu-std-1)
removeKeepBest(G, int(n*2) + 5, maxDistForRead=2)
removeEdge(G)
removeHighSpanTags(G, 12)
@ -812,7 +812,7 @@ def evaluateFitness(books, debugPrint=False):
print(sum(errSq)/len(errSq), 0.005*regressionLoss, 0.2*boundsLoss/len(ratedBooks), 1.0*sum(linSepLoss)/len(linSepLoss))
return sum(errSq)/len(errSq) + 0.005*regressionLoss + 0.2*boundsLoss/len(ratedBooks) - 1.0*sum(linSepLoss)/len(linSepLoss)
def train(gamma = 1):
def train(gamma = 1, maxEmptySteps=-1):
global weights
books = loadBooksFromDB()
bestWeights = copy.copy(weights)
@ -820,6 +820,7 @@ def train(gamma = 1):
w = list(weights.keys())
attr = random.choice(w)
delta = gamma * (-0.5 + (0.75 + 0.25*random.random()))
emptyStepsLeft = maxEmptySteps
while gamma > 1.0e-08:
print({'mse': best_mse, 'w': weights, 'gamma': gamma})
@ -841,11 +842,15 @@ def train(gamma = 1):
delta *= 2
if random.random() < 0.10:
attr = random.choice(w)
emptyStepsLeft = maxEmptySteps
else:
weights = copy.copy(bestWeights)
gamma *= 0.8
attr = random.choice(w)
delta = gamma * (-0.5 + (0.75 + 0.25*random.random()))
emptyStepsLeft -= 1
if emptyStepsLeft == 0:
return
def saveWeights(weights):
with open('neuralWeights.json', 'w') as f:
@ -887,13 +892,14 @@ def cliInterface():
p_train = cmds.add_parser('train', description="TODO", aliases=[])
p_train.add_argument('-g', type=float, default=1, help='learning rate gamma')
p_train.add_argument('--full', action="store_true")
p_full = cmds.add_parser('full', description="TODO", aliases=[])
args = parser.parse_args()
if args.cmd=="train":
train(args.g)
train(args.g, -1 if args.full else 32)
exit()
G, books = buildFullGraph()