Even better loss tracking and more sparse checkpointing

This commit is contained in:
Dominik Moritz Roth 2021-09-21 11:05:28 +02:00
parent 5bd50e2b43
commit 054b4494d2

View File

@ -31,8 +31,8 @@ class Model(nn.Module):
def train(model, seq_len=16*64): def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5) tid = str(int(random.random()*99999)).zfill(5)
ltLoss = 100 ltLoss = 50
lltLoss = 100 lltLoss = 51
model.train() model.train()
criterion = nn.BCELoss() criterion = nn.BCELoss()
@ -57,8 +57,9 @@ def train(model, seq_len=16*64):
ltLoss = ltLoss*0.9 + 0.1*loss.item() ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss}) print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
if ltLoss < 0.40 and lltLoss < 0.475: if ltLoss < 0.20 and lltLoss < 0.225:
print("[*] Hell Yeah! Poccing!") print("[*] Hell Yeah! Poccing! Got sup")
if epoch % 8 == 0:
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
model = Model() model = Model()