From e4e2a18606a9c950cc0d9011bba2848597f929ad Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 22 Sep 2021 11:09:15 +0200 Subject: [PATCH] Tweaked training params --- train.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/train.py b/train.py index 4d33f32..bbcc96f 100644 --- a/train.py +++ b/train.py @@ -12,12 +12,12 @@ from model import Model def train(model, seq_len=16*64): tid = str(int(random.random()*99999)).zfill(5) print("[i] I am "+str(tid)) - ltLoss = 50 - lltLoss = 52 + ltLoss = 0.75 + lltLoss = 0.80 model.train() criterion = nn.BCELoss() - optimizer = optim.Adam(model.parameters(), lr=0.0001) + optimizer = optim.Adam(model.parameters(), lr=0.001) state_h = [None,None] state_c = [None,None] @@ -45,10 +45,16 @@ def train(model, seq_len=16*64): correct[t] = round(y_pred.item()) == t ltLoss = ltLoss*0.9 + 0.1*loss.item() - lltLoss = lltLoss*0.9 + 0.1*ltLoss - print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" }) + lltLoss = lltLoss*0.9 + 0.1*ltLoss + print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" }) if epoch % 8 == 0: torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') + if lltLoss > 0.49: + print("[~] My emperor! I've failed! A BARREL ROLL!") + elif lltLoss < 0.45: + print("[~] Booyaaa!!!!") + else: + print("[~] Meh...") model = Model()