Even better loss tracking and more sparse checkpointing
This commit is contained in:
parent
5bd50e2b43
commit
054b4494d2
@ -31,8 +31,8 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
def train(model, seq_len=16*64):
|
def train(model, seq_len=16*64):
|
||||||
tid = str(int(random.random()*99999)).zfill(5)
|
tid = str(int(random.random()*99999)).zfill(5)
|
||||||
ltLoss = 100
|
ltLoss = 50
|
||||||
lltLoss = 100
|
lltLoss = 51
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
@ -57,9 +57,10 @@ def train(model, seq_len=16*64):
|
|||||||
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||||
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||||
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
|
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
|
||||||
if ltLoss < 0.40 and lltLoss < 0.475:
|
if ltLoss < 0.20 and lltLoss < 0.225:
|
||||||
print("[*] Hell Yeah! Poccing!")
|
print("[*] Hell Yeah! Poccing! Got sup")
|
||||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
if epoch % 8 == 0:
|
||||||
|
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user