Better loss tracking and implemented saving of network-wheights

This commit is contained in:
Dominik Moritz Roth 2021-09-21 09:49:27 +02:00
parent 627bf370bc
commit 5bd50e2b43
2 changed files with 12 additions and 2 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
*.pyc *.pyc
__pycache__ __pycache__
model_savepoints/*

View File

@ -1,8 +1,9 @@
import torch import torch
from torch import nn from torch import nn
import numpy as np
from torch import nn, optim from torch import nn, optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import numpy as np
import random
import shark import shark
@ -29,6 +30,9 @@ class Model(nn.Module):
torch.zeros(3, 1, 16)) torch.zeros(3, 1, 16))
def train(model, seq_len=16*64): def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5)
ltLoss = 100
lltLoss = 100
model.train() model.train()
criterion = nn.BCELoss() criterion = nn.BCELoss()
@ -50,7 +54,12 @@ def train(model, seq_len=16*64):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y}) ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
if ltLoss < 0.40 and lltLoss < 0.475:
print("[*] Hell Yeah! Poccing!")
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
model = Model() model = Model()