Smol changes on parameters
This commit is contained in:
parent
d7c5d57d93
commit
4f250f61c3
19
train.py
19
train.py
@ -9,7 +9,7 @@ import math
|
||||
import shark
|
||||
from model import Model
|
||||
|
||||
def train(model, seq_len=16*128): # 0.25KiB
|
||||
def train(model, seq_len=16*256): # 0.5KiB
|
||||
tid = str(int(random.random()*99999)).zfill(5)
|
||||
print("[i] I am "+str(tid))
|
||||
ltLoss = 0.75
|
||||
@ -17,12 +17,14 @@ def train(model, seq_len=16*128): # 0.25KiB
|
||||
model.train()
|
||||
|
||||
criterion = nn.BCELoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
||||
|
||||
state_h = [None,None]
|
||||
state_c = [None,None]
|
||||
blob = [None,None]
|
||||
correct = [None,None]
|
||||
err = [None,None]
|
||||
ltErr = 0.5
|
||||
|
||||
for epoch in range(1024):
|
||||
state_h[0], state_c[0] = model.init_state(seq_len)
|
||||
@ -44,17 +46,16 @@ def train(model, seq_len=16*128): # 0.25KiB
|
||||
optimizer.step()
|
||||
|
||||
correct[t] = round(y_pred.item()) == t
|
||||
err[t] = abs(t - y_pred.item())
|
||||
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||
ltErr = ltErr*0.99 + (err[0] + err[1])*0.005
|
||||
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||
print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
|
||||
if epoch % 8 == 0:
|
||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||
if lltLoss > 0.49:
|
||||
print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(100-(err[0]+err[1])*50))+"%" })
|
||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||
if 0.45 < ltErr < 0.55:
|
||||
print("[~] My emperor! I've failed! A BARREL ROLL!")
|
||||
elif lltLoss < 0.45:
|
||||
print("[~] Booyaaa!!!!")
|
||||
else:
|
||||
print("[~] Meh...")
|
||||
print("[~] Booyaaa!!!!")
|
||||
|
||||
model = Model()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user