shark/train.py

63 lines
2.0 KiB
Python
Raw Normal View History

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import math
import shark
from model import Model
2021-09-22 18:53:56 +02:00
def train(model, seq_len=16*256): # 0.5KiB
    """Train ``model`` online to tell the two ``shark`` sample classes apart.

    For each of 1024 epochs one sample per class (0 and 1) is fetched with
    ``shark.getSample(seq_len, cls)``. Both samples are then walked in
    lockstep: every element is expanded to its 8-digit binary representation
    and fed to the model as a ``(1, 1, 8)`` float tensor, the target being
    the class index. One optimizer step is taken per element and class.
    The recurrent state is kept separately per class and detached after
    every step, so back-propagation never reaches past the current step.

    Args:
        model: a ``torch.nn.Module`` exposing ``init_state(seq_len)`` that
            returns an ``(h, c)`` pair, and whose call
            ``model(x, (h, c))`` returns ``(y_pred, (h, c))`` with a
            single-element ``y_pred`` (presumably an LSTM-style classifier
            with a sigmoid output, as ``BCELoss`` is used — confirm against
            ``model.py``).
        seq_len: length passed to ``shark.getSample`` and
            ``model.init_state``.

    Side effects:
        Prints progress once per epoch and writes a checkpoint to
        ``model_savepoints/<tid>_<epoch>.n`` after every epoch.
    """
    import os  # local import: only needed to prepare the checkpoint directory

    # Random 5-digit run id, used to tell checkpoints of different runs apart.
    tid = str(int(random.random()*99999)).zfill(5)
    print("[i] I am "+str(tid))

    # Exponential moving averages of the loss (short- and long-term).
    ltLoss = 0.75
    lltLoss = 0.80
    model.train()

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Create the checkpoint directory up front so the first save cannot fail
    # after a whole epoch of work.
    os.makedirs('model_savepoints', exist_ok=True)

    # Index 0 / 1 of each list belongs to sample class 0 / 1.
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]
    err = [None, None]
    # Moving average of the absolute prediction error; 0.5 means "guessing".
    ltErr = 0.5

    for epoch in range(1024):
        for t in range(2):
            state_h[t], state_c[t] = model.init_state(seq_len)
            blob[t], _ = shark.getSample(seq_len, t)

        for i in range(len(blob[0])):
            for t in range(2):
                # Expand the element into its 8 binary digits (MSB first).
                bits = [float(d) for d in bin(blob[t][i])[2:].zfill(8)]
                x = torch.tensor([[bits]], dtype=torch.float32)
                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))

                # Cut the graph here so the next step does not backpropagate
                # through this one.
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()

                # Gradients accumulate across backward() calls, so they must
                # be cleared before every step; clearing them only once per
                # epoch would make each step apply the sum of all gradients
                # seen so far in that epoch.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pred = y_pred.item()
                correct[t] = round(pred) == t
                err[t] = abs(t - pred)
                ltLoss = ltLoss*0.9 + 0.1*loss.item()
            # 0.99 * old + 0.01 * mean(err[0], err[1])
            ltErr = ltErr*0.99 + (err[0] + err[1])*0.005

        lltLoss = lltLoss*0.9 + 0.1*ltLoss
        print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(100-(err[0]+err[1])*50))+"%" })
        torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')

    # NOTE(review): the nesting level of this final verdict was reconstructed
    # from a copy of the file without indentation; it is assumed to run once
    # after all epochs — confirm against the repository.
    # An average error close to 0.5 means the model did no better than chance.
    if 0.45 < ltErr < 0.55:
        print("[~] My emperor! I've failed! A BARREL ROLL!")
    else:
        print("[~] Booyaaa!!!!")
# Script entry point: build a fresh, untrained Model and start training it
# immediately with the default sequence length.
model = Model()
train(model=model)