Tweaked training params

This commit is contained in:
Dominik Moritz Roth 2021-09-22 11:09:15 +02:00
parent c95d5a2e8b
commit e4e2a18606

View File

@ -12,12 +12,12 @@ from model import Model
def train(model, seq_len=16*64): def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5) tid = str(int(random.random()*99999)).zfill(5)
print("[i] I am "+str(tid)) print("[i] I am "+str(tid))
ltLoss = 50 ltLoss = 0.75
lltLoss = 52 lltLoss = 0.80
model.train() model.train()
criterion = nn.BCELoss() criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001) optimizer = optim.Adam(model.parameters(), lr=0.001)
state_h = [None,None] state_h = [None,None]
state_c = [None,None] state_c = [None,None]
@ -45,10 +45,16 @@ def train(model, seq_len=16*64):
correct[t] = round(y_pred.item()) == t correct[t] = round(y_pred.item()) == t
ltLoss = ltLoss*0.9 + 0.1*loss.item() ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" }) print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
if epoch % 8 == 0: if epoch % 8 == 0:
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
if lltLoss > 0.49:
print("[~] My emperor! I've failed! A BARREL ROLL!")
elif lltLoss < 0.45:
print("[~] Booyaaa!!!!")
else:
print("[~] Meh...")
model = Model() model = Model()