shark/train.py

62 lines
2.0 KiB
Python

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import math
import shark
from model import Model
def _byte_to_bits(value):
    """Expand one byte into the tensor fed to the model.

    Returns a float32 tensor of shape (1, 1, n) holding the binary digits of
    *value*, most significant bit first, zero-padded to at least 8 digits
    (so n == 8 for any value in 0..255).
    """
    digits = bin(value)[2:].zfill(8)
    return torch.tensor([[[float(d) for d in digits]]], dtype=torch.float32)


def train(
    model,
    seq_len: int = 16 * 512,
    *,
    epochs: int = 1024,
    save_dir: str = "model_savepoints",
    sample_fn=None,
) -> None:
    """Train *model* to tell two sample classes apart, one byte at a time.

    Every epoch draws one sample per class label (0 and 1), then walks both
    samples in lock-step. Each byte is expanded to its bits, pushed through
    the model together with that label's own (h, c) state pair, scored with
    binary cross-entropy against the label, and followed immediately by an
    optimizer step. A progress dict is printed once per epoch, a savepoint
    is written every 8th epoch, and a verdict line is printed at the end.

    Contract the model must satisfy (as used below):
      * ``model.init_state(seq_len)`` returns a pair ``(h, c)``;
      * ``model(x, (h, c))`` returns ``(y_pred, (h, c))`` where ``y_pred``
        holds exactly one element, a probability in [0, 1].

    Args:
        model: The network to train; it is put into training mode.
        seq_len: Forwarded to ``model.init_state`` and to the sampler.
            The original source annotated the default (8192) as "1KiB".
        epochs: Number of epochs to run (the original hard-coded 1024).
        save_dir: Directory for savepoints; created if missing.
        sample_fn: Callable ``(seq_len, label) -> (blob, extra)`` where
            ``blob`` is a sized iterable of byte values. Defaults to
            ``shark.getSample``.

    NOTE(review): this file reached review with its indentation stripped.
    The nesting below is reconstructed from data flow. The state is detached
    right before ``backward()``, which only makes sense when backward/step
    run once per byte, so those statements sit in the innermost loop.
    Updating the running loss averages per step (rather than per epoch) is
    the least certain part of the reconstruction; confirm against the
    original file.
    """
    # Function-scope import keeps this block independent of the file header.
    import os

    if sample_fn is None:
        sample_fn = shark.getSample

    # Random five-digit tag so that savepoints of parallel runs do not clash.
    run_id = f"{random.randrange(99999):05d}"
    print("[i] I am " + run_id)

    # Exponentially smoothed loss, and a second smoothing pass on top of it.
    # The initial values are the ones the original code started from.
    smoothed_loss = 0.75
    double_smoothed_loss = 0.80

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    labels = (0, 1)
    # The target tensors never change, so build them once.
    targets = [torch.tensor(label, dtype=torch.float32) for label in labels]
    correct = [None, None]
    last_loss = None

    for epoch in range(epochs):
        # Fresh state and a fresh sample for each label, in label order.
        states = []
        for _ in labels:
            state_h, state_c = model.init_state(seq_len)
            states.append((state_h, state_c))
        blobs = []
        for label in labels:
            blob, _ = sample_fn(seq_len, label)
            blobs.append(blob)

        # zip() pairs the i-th byte of both samples and stops at the shorter
        # one, so a length mismatch cannot raise an IndexError.
        for byte_pair in zip(*blobs):
            for label, byte in enumerate(byte_pair):
                x = _byte_to_bits(byte)
                y_pred, (state_h, state_c) = model(x, states[label])
                loss = criterion(y_pred[0][0][0], targets[label])

                # Cut the graph here: the next step must not backpropagate
                # through this step's (already freed) history.
                states[label] = (state_h.detach(), state_c.detach())

                # Gradients accumulate across backward() calls. Since we step
                # after every prediction they must be cleared every time, not
                # just once per epoch.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                correct[label] = round(y_pred.item()) == label
                last_loss = loss.item()
                smoothed_loss = smoothed_loss * 0.9 + 0.1 * last_loss
                double_smoothed_loss = (
                    double_smoothed_loss * 0.9 + 0.1 * smoothed_loss
                )

        accuracy = int(max(0, 1 - math.sqrt(double_smoothed_loss)) * 100)
        print({
            'epoch': epoch,
            'loss': last_loss,
            'lltLoss': double_smoothed_loss,
            'ok0': correct[0],
            'ok1': correct[1],
            'succ': correct[0] and correct[1],
            'acc': str(accuracy) + "%",
        })

        if epoch % 8 == 0:
            savepoint = os.path.join(save_dir, f"{run_id}_{epoch}.n")
            torch.save(model.state_dict(), savepoint)

    # Final verdict based on the doubly smoothed loss.
    if double_smoothed_loss > 0.49:
        print("[~] My emperor! I've failed! A BARREL ROLL!")
    elif double_smoothed_loss < 0.45:
        print("[~] Booyaaa!!!!")
    else:
        print("[~] Meh...")
# Script entry point: build a Model with its default constructor arguments and
# train it with train()'s default settings. There is no __main__ guard, so this
# also runs when the module is imported.
model = Model()
train(model)