Better loss tracking and implemented saving of network weights

Dominik Moritz Roth 2021-09-21 09:49:27 +02:00
parent 627bf370bc
commit 5bd50e2b43
2 changed files with 12 additions and 2 deletions

.gitignore

@@ -1,2 +1,3 @@
 *.pyc
 __pycache__
+model_savepoints/*


@@ -1,8 +1,9 @@
 import torch
-from torch import nn
+from torch import nn, optim
+from torch.utils.data import DataLoader
 import numpy as np
 import random
 import shark
@@ -29,6 +30,9 @@ class Model(nn.Module):
                 torch.zeros(3, 1, 16))

 def train(model, seq_len=16*64):
+    tid = str(int(random.random()*99999)).zfill(5)
+    ltLoss = 100
+    lltLoss = 100
     model.train()
     criterion = nn.BCELoss()
@@ -50,7 +54,12 @@ def train(model, seq_len=16*64):
         loss.backward()
         optimizer.step()
-        print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y})
+        ltLoss = ltLoss*0.9 + 0.1*loss.item()
+        lltLoss = lltLoss*0.9 + 0.1*ltLoss
+        print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
+        if ltLoss < 0.40 and lltLoss < 0.475:
+            print("[*] Hell Yeah! Poccing!")
+            torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')

 model = Model()
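
Taken together, the changes implement a simple checkpointing scheme: each epoch's loss is folded into an exponential moving average (ltLoss), that average is smoothed once more into lltLoss, and the current weights are written to model_savepoints/ once both averages fall below fixed thresholds. The standalone sketch below reproduces that pattern outside the repository; DummyModel, the synthetic training batch, the Adam optimizer, and the os.makedirs call are illustrative assumptions, while the smoothing factor, the thresholds, and the save-path scheme are taken from the diff.

import os
import random

import torch
from torch import nn, optim

class DummyModel(nn.Module):
    """Stand-in for the repository's Model class (assumption: any nn.Module works here)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())

    def forward(self, x):
        return self.net(x)

def train(model, epochs=200):
    # Random, zero-padded run id, mirroring the tid in the diff.
    tid = str(int(random.random() * 99999)).zfill(5)
    # Exponential moving averages of the loss; start high so early epochs cannot trigger a save.
    ltLoss = 100.0    # smoothed loss
    lltLoss = 100.0   # smoothing of the smoothed loss
    os.makedirs('model_savepoints', exist_ok=True)  # assumption: the diff expects this directory to exist

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)  # assumption: the repository may use a different optimizer

    for epoch in range(epochs):
        x = torch.rand(32, 8)                            # placeholder batch
        y = (x.mean(dim=1, keepdim=True) > 0.5).float()  # placeholder target
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Same update rule as the diff: keep 90% of the running average, blend in 10% of the new value.
        ltLoss = ltLoss * 0.9 + 0.1 * loss.item()
        lltLoss = lltLoss * 0.9 + 0.1 * ltLoss
        print({'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})

        # Save a checkpoint once both averages drop below the thresholds used in the diff.
        if ltLoss < 0.40 and lltLoss < 0.475:
            path = 'model_savepoints/' + tid + '_' + str(epoch) + '.n'
            torch.save(model.state_dict(), path)

if __name__ == '__main__':
    train(DummyModel())

A checkpoint written this way can later be restored with model.load_state_dict(torch.load(path)).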