Even better loss tracking and more sparse checkpointing
parent 5bd50e2b43
commit 054b4494d2
@@ -31,8 +31,8 @@ class Model(nn.Module):
 
 def train(model, seq_len=16*64):
     tid = str(int(random.random()*99999)).zfill(5)
-    ltLoss = 100
-    lltLoss = 100
+    ltLoss = 50
+    lltLoss = 51
     model.train()
 
     criterion = nn.BCELoss()
@@ -57,9 +57,10 @@ def train(model, seq_len=16*64):
         ltLoss = ltLoss*0.9 + 0.1*loss.item()
+        lltLoss = lltLoss*0.9 + 0.1*ltLoss
         print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss})
         if ltLoss < 0.40 and lltLoss < 0.475:
             print("[*] Hell Yeah! Poccing!")
             torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
         if ltLoss < 0.20 and lltLoss < 0.225:
             print("[*] Hell Yeah! Poccing! Got sup")
             if epoch % 8 == 0:
                 torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
 
 model = Model()
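For context, the checkpoint gating in this diff keys off two exponential moving averages of the training loss: ltLoss smooths the raw per-step loss, and lltLoss smooths ltLoss again, so a single lucky batch is less likely to trigger a save; checkpoints are then written only when both averages are under a threshold, and at most every few epochs. Below is a minimal, self-contained sketch of that pattern; the simulated loss, threshold values, and save interval are illustrative placeholders, not the model or exact logic from this repository.

import random

def should_checkpoint(ltLoss, lltLoss, epoch,
                      lt_thresh=0.40, llt_thresh=0.475, interval=8):
    # Save only when both smoothed losses are under their thresholds,
    # and only every `interval`-th epoch to keep checkpoints sparse.
    # (Thresholds and interval are placeholders for illustration.)
    return ltLoss < lt_thresh and lltLoss < llt_thresh and epoch % interval == 0

def train_loop(num_epochs=100):
    # Initialise the EMAs near the expected loss scale so they settle quickly
    # (the commit lowers the initial values from 100 to 50/51 for the same reason).
    ltLoss, lltLoss = 50.0, 51.0
    for epoch in range(num_epochs):
        loss = random.uniform(0.1, 0.6)          # stand-in for a real training step
        ltLoss = ltLoss * 0.9 + 0.1 * loss       # first-level EMA of the raw loss
        lltLoss = lltLoss * 0.9 + 0.1 * ltLoss   # second-level EMA of ltLoss
        if should_checkpoint(ltLoss, lltLoss, epoch):
            print({'epoch': epoch, 'loss': round(loss, 3),
                   'ltLoss': round(ltLoss, 3), 'lltLoss': round(lltLoss, 3)})
            # In the real trainer this is where torch.save(model.state_dict(), ...) runs.

if __name__ == "__main__":
    train_loop()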