shark/train.py

63 lines
2.0 KiB
Python

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import math
import shark
from model import Model
def _byte_to_bits(value):
    """Expand an integer into its binary digits as a list of floats.

    Digits are ordered most-significant first and left-padded with zeros to
    a minimum width of 8, so a byte value in 0..255 yields exactly 8 entries.
    """
    return [float(digit) for digit in bin(value)[2:].zfill(8)]


def train(
    model,
    seq_len=16*256,  # 0.5KiB
    epochs=1024,
    lr=0.0001,
    save_dir='model_savepoints',
):
    """Train ``model`` to tell class-0 samples from class-1 samples.

    Every epoch draws one sample per class from ``shark.getSample`` and
    feeds both to the model one byte at a time (each byte as 8 bits),
    keeping a separate recurrent state per class. The model is optimised
    after every single byte with binary cross-entropy against the class
    label. After each epoch the smoothed statistics are printed and the
    model's ``state_dict`` is written to ``save_dir``.

    Args:
        model: A ``torch.nn.Module`` that provides ``init_state(seq_len)``
            returning an ``(h, c)`` pair, and whose call
            ``model(x, (h, c))`` returns ``(y_pred, (h, c))`` where
            ``y_pred`` holds exactly one element (it is read with
            ``.item()``). BCELoss requires that element to lie in [0, 1].
        seq_len: Length passed to ``shark.getSample`` and
            ``model.init_state``.
        epochs: Number of epochs to run.
        lr: Learning rate for the Adam optimizer.
        save_dir: Directory receiving one checkpoint per epoch, named
            ``<run id>_<epoch>.n``. Created if it does not exist.

    Raises:
        ValueError: If a sample drawn from ``shark.getSample`` is empty.
    """
    import os  # function-scope import: only needed for checkpoint paths

    # Random 5-digit run id so checkpoints of parallel runs do not collide.
    tid = str(int(random.random()*99999)).zfill(5)
    print(f"[i] I am {tid}")

    # Create the checkpoint directory up front so a missing directory fails
    # (or is fixed) immediately instead of after the first full epoch.
    os.makedirs(save_dir, exist_ok=True)

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # One target per class; the class index itself is the expected output.
    targets = [torch.tensor(label, dtype=torch.float32) for label in range(2)]

    # Exponentially smoothed statistics ("lt" = long-term).
    lt_loss = 0.75
    llt_loss = 0.80
    lt_err = 0.5

    # All per-class bookkeeping is indexed by the class label (0 or 1).
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]
    err = [None, None]

    for epoch in range(epochs):
        # Fresh recurrent state and a fresh sample for each class.
        for label in range(2):
            state_h[label], state_c[label] = model.init_state(seq_len)
        for label in range(2):
            blob[label], _ = shark.getSample(seq_len, label)

        loss = None
        for byte_pair in zip(blob[0], blob[1]):
            for label, byte in enumerate(byte_pair):
                # Shape (1, 1, 8): a single step carrying the byte's bits.
                x = torch.tensor([[_byte_to_bits(byte)]], dtype=torch.float32)
                y_pred, (state_h[label], state_c[label]) = model(
                    x, (state_h[label], state_c[label])
                )
                loss = criterion(y_pred[0][0][0], targets[label])

                # Cut the graph here so the next step does not backpropagate
                # through this one.
                state_h[label] = state_h[label].detach()
                state_c[label] = state_c[label].detach()

                # Gradients accumulate across backward() calls, so they must
                # be cleared before every step; clearing only once per epoch
                # would make each step apply the sum of all earlier ones.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                prediction = y_pred.item()
                correct[label] = round(prediction) == label
                err[label] = abs(label - prediction)

        if loss is None:
            raise ValueError(
                "shark.getSample returned an empty sample; nothing to train on"
            )

        # Statistics use the final byte of the epoch for both classes.
        last_loss = loss.item()
        lt_loss = lt_loss*0.9 + 0.1*last_loss
        # Equivalent to 0.99*lt_err + 0.01*mean(err[0], err[1]).
        lt_err = lt_err*0.99 + (err[0] + err[1])*0.005
        llt_loss = llt_loss*0.9 + 0.1*lt_loss

        print({
            'epoch': epoch,
            'loss': last_loss,
            'lltLoss': llt_loss,
            'ok0': correct[0],
            'ok1': correct[1],
            'succ': correct[0] and correct[1],
            'acc': f"{int(100 - (err[0] + err[1])*50)}%",
        })
        torch.save(model.state_dict(), os.path.join(save_dir, f"{tid}_{epoch}.n"))

    # A smoothed error still close to its 0.5 starting value is reported as
    # a failed run.
    if 0.45 < lt_err < 0.55:
        print("[~] My emperor! I've failed! A BARREL ROLL!")
    else:
        print("[~] Booyaaa!!!!")
# Script entry: build a default-constructed Model and hand it to train(),
# passing no other arguments so train() runs with its default settings.
model = Model()
train(model)