diff --git a/discriminator.py b/discriminator.py index 56198df..99c1002 100644 --- a/discriminator.py +++ b/discriminator.py @@ -31,8 +31,8 @@ class Model(nn.Module): def train(model, seq_len=16*64): tid = str(int(random.random()*99999)).zfill(5) - ltLoss = 100 - lltLoss = 100 + ltLoss = 50 + lltLoss = 51 model.train() criterion = nn.BCELoss() @@ -57,9 +57,10 @@ def train(model, seq_len=16*64): ltLoss = ltLoss*0.9 + 0.1*loss.item() lltLoss = lltLoss*0.9 + 0.1*ltLoss print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss}) - if ltLoss < 0.40 and lltLoss < 0.475: - print("[*] Hell Yeah! Poccing!") - torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') + if ltLoss < 0.20 and lltLoss < 0.225: + print("[*] Hell Yeah! Poccing! Got sup") + if epoch % 8 == 0: + torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') model = Model()