Tweaked training params
This commit is contained in:
parent
c95d5a2e8b
commit
e4e2a18606
14
train.py
14
train.py
@ -12,12 +12,12 @@ from model import Model
|
|||||||
def train(model, seq_len=16*64):
|
def train(model, seq_len=16*64):
|
||||||
tid = str(int(random.random()*99999)).zfill(5)
|
tid = str(int(random.random()*99999)).zfill(5)
|
||||||
print("[i] I am "+str(tid))
|
print("[i] I am "+str(tid))
|
||||||
ltLoss = 50
|
ltLoss = 0.75
|
||||||
lltLoss = 52
|
lltLoss = 0.80
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
state_h = [None,None]
|
state_h = [None,None]
|
||||||
state_c = [None,None]
|
state_c = [None,None]
|
||||||
@ -46,9 +46,15 @@ def train(model, seq_len=16*64):
|
|||||||
correct[t] = round(y_pred.item()) == t
|
correct[t] = round(y_pred.item()) == t
|
||||||
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||||
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||||
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
|
print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
|
||||||
if epoch % 8 == 0:
|
if epoch % 8 == 0:
|
||||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||||
|
if lltLoss > 0.49:
|
||||||
|
print("[~] My emperor! I've failed! A BARREL ROLL!")
|
||||||
|
elif lltLoss < 0.45:
|
||||||
|
print("[~] Booyaaa!!!!")
|
||||||
|
else:
|
||||||
|
print("[~] Meh...")
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user