diff --git a/.gitignore b/.gitignore index a295864..4382a7a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.pyc __pycache__ +model_savepoints/* diff --git a/discriminator.py b/discriminator.py index ade1d2b..56198df 100644 --- a/discriminator.py +++ b/discriminator.py @@ -1,8 +1,9 @@ import torch from torch import nn -import numpy as np from torch import nn, optim from torch.utils.data import DataLoader +import numpy as np +import random import shark @@ -29,6 +30,9 @@ class Model(nn.Module): torch.zeros(3, 1, 16)) def train(model, seq_len=16*64): + tid = str(int(random.random()*99999)).zfill(5) + ltLoss = 100 + lltLoss = 100 model.train() criterion = nn.BCELoss() @@ -50,7 +54,12 @@ def train(model, seq_len=16*64): loss.backward() optimizer.step() - print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y}) + ltLoss = ltLoss*0.9 + 0.1*loss.item() + lltLoss = lltLoss*0.9 + 0.1*ltLoss + print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss}) + if ltLoss < 0.40 and lltLoss < 0.475: + print("[*] Hell Yeah! Poccing!") + torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') model = Model()