Smol changes on parameters

This commit is contained in:
Dominik Moritz Roth 2021-09-22 18:53:56 +02:00
parent d7c5d57d93
commit 4f250f61c3

View File

@ -9,7 +9,7 @@ import math
import shark import shark
from model import Model from model import Model
def train(model, seq_len=16*128): # 0.25KiB def train(model, seq_len=16*256): # 0.5KiB
tid = str(int(random.random()*99999)).zfill(5) tid = str(int(random.random()*99999)).zfill(5)
print("[i] I am "+str(tid)) print("[i] I am "+str(tid))
ltLoss = 0.75 ltLoss = 0.75
@ -17,12 +17,14 @@ def train(model, seq_len=16*128): # 0.25KiB
model.train() model.train()
criterion = nn.BCELoss() criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.parameters(), lr=0.0001)
state_h = [None,None] state_h = [None,None]
state_c = [None,None] state_c = [None,None]
blob = [None,None] blob = [None,None]
correct = [None,None] correct = [None,None]
err = [None,None]
ltErr = 0.5
for epoch in range(1024): for epoch in range(1024):
state_h[0], state_c[0] = model.init_state(seq_len) state_h[0], state_c[0] = model.init_state(seq_len)
@ -44,17 +46,16 @@ def train(model, seq_len=16*128): # 0.25KiB
optimizer.step() optimizer.step()
correct[t] = round(y_pred.item()) == t correct[t] = round(y_pred.item()) == t
err[t] = abs(t - y_pred.item())
ltLoss = ltLoss*0.9 + 0.1*loss.item() ltLoss = ltLoss*0.9 + 0.1*loss.item()
ltErr = ltErr*0.99 + (err[0] + err[1])*0.005
lltLoss = lltLoss*0.9 + 0.1*ltLoss lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" }) print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(100-(err[0]+err[1])*50))+"%" })
if epoch % 8 == 0: torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n') if 0.45 < ltErr < 0.55:
if lltLoss > 0.49:
print("[~] My emperor! I've failed! A BARREL ROLL!") print("[~] My emperor! I've failed! A BARREL ROLL!")
elif lltLoss < 0.45:
print("[~] Booyaaa!!!!")
else: else:
print("[~] Meh...") print("[~] Booyaaa!!!!")
model = Model() model = Model()