import torch
from torch import nn, optim
import random
import math

import shark
from model import Model


def train(model, seq_len=16 * 64):
    # Random five-digit run id used to tag checkpoint files.
    tid = str(int(random.random() * 99999)).zfill(5)
    print("[i] I am " + tid)

    # Exponential moving averages of the loss, seeded high so the reported
    # accuracy estimate starts near 0%.
    ltLoss = 50
    lltLoss = 52

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Separate recurrent state and bookkeeping for the two classes (t = 0, t = 1).
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]

    for epoch in range(1024):
        # Fresh LSTM state and a fresh sample per class every epoch; the
        # sample length ramps up from 16 bytes toward seq_len.
        state_h[0], state_c[0] = model.init_state(seq_len)
        state_h[1], state_c[1] = model.init_state(seq_len)
        blob[0], _ = shark.getSample(min(seq_len, 16 * (epoch + 1)), 0)
        blob[1], _ = shark.getSample(min(seq_len, 16 * (epoch + 1)), 1)

        for i in range(len(blob[0])):
            for t in range(2):
                # Feed one byte at a time as a vector of its 8 bits.
                bits = [float(d) for d in bin(blob[t][i])[2:].zfill(8)]
                x = torch.tensor([[bits]], dtype=torch.float32)

                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))

                # Truncated backpropagation: detach the recurrent state so the
                # graph does not keep growing across steps.
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                correct[t] = round(y_pred.item()) == t
                ltLoss = ltLoss * 0.9 + 0.1 * loss.item()
                lltLoss = lltLoss * 0.9 + 0.1 * ltLoss

        print({
            'epoch': epoch,
            'loss': loss.item(),
            'ltLoss': ltLoss,
            'ok0': correct[0],
            'ok1': correct[1],
            'succ': correct[0] and correct[1],
            'acc': str(int(max(0, 1 - math.sqrt(lltLoss)) * 100)) + "%",
        })

        if epoch % 8 == 0:
            torch.save(model.state_dict(), 'model_savepoints/' + tid + '_' + str(epoch) + '.n')


model = Model()
train(model)
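

# ---------------------------------------------------------------------------
# The loop above only relies on an interface from model.Model:
#   * init_state(seq_len) -> (h0, c0) initial LSTM state, and
#   * forward(x, (h, c)) -> (y_pred, (h, c)) with a sigmoid output in (0, 1),
#     as required by nn.BCELoss.
# The class below is a minimal sketch of a model satisfying that interface,
# included here only as documentation of the assumption; the real model.Model
# may use different layer sizes or architecture.
class SketchModel(nn.Module):
    def __init__(self, input_size=8, hidden_size=64, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, state):
        # x has shape (batch=1, seq=1, features=8): one byte's bits per step.
        out, state = self.lstm(x, state)
        # Squash to (0, 1) so the output can be fed directly to nn.BCELoss.
        return torch.sigmoid(self.fc(out)), state

    def init_state(self, seq_len):
        # seq_len is accepted for interface compatibility; a zero state does
        # not depend on it.
        return (torch.zeros(self.num_layers, 1, self.hidden_size),
                torch.zeros(self.num_layers, 1, self.hidden_size))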