From 4f250f61c3a360a048151f36378bdf69ebac7be2 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 22 Sep 2021 18:53:56 +0200 Subject: [PATCH] Smol changes on parameters --- train.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index a915aaf..5ddeaba 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import math import shark from model import Model -def train(model, seq_len=16*128): # 0.25KiB +def train(model, seq_len=16*256): # 0.5KiB tid = str(int(random.random()*99999)).zfill(5) print("[i] I am "+str(tid)) ltLoss = 0.75 @@ -17,12 +17,14 @@ def train(model, seq_len=16*128): # 0.25KiB model.train() criterion = nn.BCELoss() - optimizer = optim.Adam(model.parameters(), lr=0.001) + optimizer = optim.Adam(model.parameters(), lr=0.0001) state_h = [None,None] state_c = [None,None] blob = [None,None] correct = [None,None] + err = [None,None] + ltErr = 0.5 for epoch in range(1024): state_h[0], state_c[0] = model.init_state(seq_len) @@ -44,17 +46,16 @@ def train(model, seq_len=16*128): # 0.25KiB optimizer.step() correct[t] = round(y_pred.item()) == t + err[t] = abs(t - y_pred.item()) ltLoss = ltLoss*0.9 + 0.1*loss.item() + ltErr = ltErr*0.99 + (err[0] + err[1])*0.005 lltLoss = lltLoss*0.9 + 0.1*ltLoss - print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" }) - if epoch % 8 == 0: - torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') - if lltLoss > 0.49: + print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(100-(err[0]+err[1])*50))+"%" }) + torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') + if 0.45 < ltErr < 0.55: print("[~] My emperor! I've failed! A BARREL ROLL!") - elif lltLoss < 0.45: - print("[~] Booyaaa!!!!") else: - print("[~] Meh...") + print("[~] Booyaaa!!!!") model = Model()