Better loss tracking and implemented saving of network-wheights
This commit is contained in:
parent
627bf370bc
commit
5bd50e2b43
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,2 +1,3 @@
|
|||||||
*.pyc
|
*.pyc
|
||||||
__pycache__
|
__pycache__
|
||||||
|
model_savepoints/*
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import numpy as np
|
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
|
||||||
import shark
|
import shark
|
||||||
|
|
||||||
@ -29,6 +30,9 @@ class Model(nn.Module):
|
|||||||
torch.zeros(3, 1, 16))
|
torch.zeros(3, 1, 16))
|
||||||
|
|
||||||
def train(model, seq_len=16*64):
|
def train(model, seq_len=16*64):
|
||||||
|
tid = str(int(random.random()*99999)).zfill(5)
|
||||||
|
ltLoss = 100
|
||||||
|
lltLoss = 100
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
@ -50,7 +54,12 @@ def train(model, seq_len=16*64):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y})
|
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||||
|
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||||
|
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
|
||||||
|
if ltLoss < 0.40 and lltLoss < 0.475:
|
||||||
|
print("[*] Hell Yeah! Poccing!")
|
||||||
|
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user