From 054b4494d22ca7a478ace1c32dad5ed725e363aa Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 21 Sep 2021 11:05:28 +0200 Subject: [PATCH] Even better loss tracking and more sparse checkpointing --- discriminator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/discriminator.py b/discriminator.py index 56198df..99c1002 100644 --- a/discriminator.py +++ b/discriminator.py @@ -31,8 +31,8 @@ class Model(nn.Module): def train(model, seq_len=16*64): tid = str(int(random.random()*99999)).zfill(5) - ltLoss = 100 - lltLoss = 100 + ltLoss = 50 + lltLoss = 51 model.train() criterion = nn.BCELoss() @@ -57,9 +57,10 @@ def train(model, seq_len=16*64): ltLoss = ltLoss*0.9 + 0.1*loss.item() lltLoss = lltLoss*0.9 + 0.1*ltLoss print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss}) - if ltLoss < 0.40 and lltLoss < 0.475: - print("[*] Hell Yeah! Poccing!") - torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') + if ltLoss < 0.20 and lltLoss < 0.225: + print("[*] Hell Yeah! Poccing! Got sup") + if epoch % 8 == 0: + torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') model = Model()