Better loss tracking and implemented saving of network-wheights
This commit is contained in:
parent
627bf370bc
commit
5bd50e2b43
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
||||
*.pyc
|
||||
__pycache__
|
||||
model_savepoints/*
|
||||
|
@ -1,8 +1,9 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from torch import nn, optim
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
import shark
|
||||
|
||||
@ -29,6 +30,9 @@ class Model(nn.Module):
|
||||
torch.zeros(3, 1, 16))
|
||||
|
||||
def train(model, seq_len=16*64):
|
||||
tid = str(int(random.random()*99999)).zfill(5)
|
||||
ltLoss = 100
|
||||
lltLoss = 100
|
||||
model.train()
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
@ -50,7 +54,12 @@ def train(model, seq_len=16*64):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y})
|
||||
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
|
||||
if ltLoss < 0.40 and lltLoss < 0.475:
|
||||
print("[*] Hell Yeah! Poccing!")
|
||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||
|
||||
model = Model()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user