shark/train.py


import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import math
import shark
from model import Model
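# `shark` and `Model` are project-local modules (not shown here): shark.getSample()
# is assumed to return a (byte sequence, label) pair for the requested class, and
# model.init_state() a fresh (hidden, cell) LSTM state pair.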
def train(model, seq_len=16*512): # 1KiB
    # random 5-digit id used to tag this training run's checkpoints
    tid = str(int(random.random()*99999)).zfill(5)
    print("[i] I am " + tid)

    # exponentially smoothed loss trackers: ltLoss follows the per-sample loss,
    # lltLoss smooths ltLoss once more
    ltLoss = 0.75
    lltLoss = 0.80
    model.train()
    criterion = nn.BCELoss()

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]
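
    # index 0 tracks the class-0 byte stream and index 1 the class-1 stream:
    # each class keeps its own LSTM state and its own blob of sample bytes per epoch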
    for epoch in range(1024):
        state_h[0], state_c[0] = model.init_state(seq_len)
        state_h[1], state_c[1] = model.init_state(seq_len)
        # requested sample length grows by 16 bytes per epoch, capped at seq_len
        blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0)
        blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1)
        optimizer.zero_grad()
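
        # walk the two blobs byte by byte; every byte becomes one LSTM step and one
        # optimizer step, with the class index t (0 or 1) as the BCE target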
        for i in range(len(blob[0])):
            for t in range(2):
                # one byte -> an 8-element 0/1 vector, shaped (batch=1, seq=1, features=8)
                x = torch.tensor([[[float(d) for d in bin(blob[t][i])[2:].zfill(8)]]], dtype=torch.float32)
                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))
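                # detach the recurrent state so gradients do not flow back across byte
                # boundaries (a form of truncated backpropagation through time)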
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()
                # note: optimizer.zero_grad() is only called once per epoch, so
                # gradients accumulate across these per-byte steps
                loss.backward()
                optimizer.step()
                correct[t] = round(y_pred.item()) == t
                ltLoss = ltLoss*0.9 + 0.1*loss.item()
                lltLoss = lltLoss*0.9 + 0.1*ltLoss
        # 'acc' is an ad-hoc accuracy estimate derived from the smoothed loss
        print({ 'epoch': epoch, 'loss': loss.item(), 'lltLoss': lltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
        # checkpoint every 8 epochs, tagged with this run's id
        if epoch % 8 == 0:
            torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
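
        # rough per-epoch verdict based on the doubly smoothed loss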
        if lltLoss > 0.49:
            print("[~] My emperor! I've failed! A BARREL ROLL!")
        elif lltLoss < 0.45:
            print("[~] Booyaaa!!!!")
        else:
            print("[~] Meh...")
model = Model()
train(model)