import math
import os
import random

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import shark
from model import Model

# NOTE(review): this file was previously collapsed onto a single line, so the
# nesting of the logging / checkpoint statements in ``train`` is reconstructed
# from the statement order -- confirm it matches the intended behaviour.

# Directory that per-epoch checkpoints are written into.
SAVE_DIR = "model_savepoints"


def _encode_value(value):
    """Return the binary digits of ``value`` as a float32 tensor of shape (1, 1, k).

    Digits are most-significant first and zero-padded on the left to at least
    8 digits, so a value in 0..255 yields exactly 8 inputs.
    """
    digits = bin(value)[2:].zfill(8)
    return torch.tensor([[[float(d) for d in digits]]], dtype=torch.float32)


def train(model, seq_len: int = 16 * 256) -> None:
    """Train ``model`` to tell whether a sequence came from stream 0 or stream 1.

    Each epoch draws one sample from ``shark.getSample(seq_len, 0)`` and one
    from ``shark.getSample(seq_len, 1)``. The samples are fed through the
    model element by element, with a separate recurrent state per stream. The
    model is trained with binary cross-entropy to output ``t`` for stream
    ``t``. After every epoch the weights are saved to
    ``model_savepoints/<tid>_<epoch>.n``, where ``tid`` is a random 5-digit run
    id printed at start-up.

    Args:
        model: module exposing ``init_state(seq_len) -> (h, c)`` and called as
            ``model(x, (h, c)) -> (y_pred, (h, c))``. ``y_pred`` must contain
            a single value in [0, 1] (required by ``nn.BCELoss``).
            NOTE(review): presumably the model ends in a sigmoid -- confirm.
        seq_len: forwarded to both ``shark.getSample`` and
            ``model.init_state``. NOTE(review): 16*256 = 4096; the old
            "0.5KiB" comment only holds if this counts bits rather than
            sample elements -- confirm units against ``shark.getSample``.
    """
    tid = str(int(random.random() * 99999)).zfill(5)
    print("[i] I am " + tid)

    # Exponential moving averages used only for progress reporting.
    lt_loss = 0.75
    llt_loss = 0.80
    lt_err = 0.5

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Per-stream (index 0 / 1) recurrent state, current sample and last result.
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    losses = [None, None]
    correct = [None, None]
    err = [None, None]

    # torch.save does not create missing directories.
    os.makedirs(SAVE_DIR, exist_ok=True)

    for epoch in range(1024):
        state_h[0], state_c[0] = model.init_state(seq_len)
        state_h[1], state_c[1] = model.init_state(seq_len)
        blob[0], _ = shark.getSample(seq_len, 0)
        blob[1], _ = shark.getSample(seq_len, 1)

        # Walk both samples in lockstep; stop at the shorter one instead of
        # raising IndexError when their lengths differ.
        steps = min(len(blob[0]), len(blob[1]))
        for i in range(steps):
            for t in range(2):
                x = _encode_value(blob[t][i])
                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))
                # Detach the carried state so the next step's backward pass
                # stops here instead of reaching into already-freed graphs.
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()

                # The optimizer steps once per element. Gradients must be
                # cleared before each backward, otherwise every step would
                # apply the sum of all gradients since the epoch began.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                prob = y_pred.item()
                losses[t] = loss.item()
                correct[t] = round(prob) == t
                err[t] = abs(t - prob)

            # Track the loss of both streams, mirroring how lt_err averages
            # the errors of both streams.
            step_loss = (losses[0] + losses[1]) / 2
            lt_loss = lt_loss * 0.9 + 0.1 * step_loss
            # Equivalent to 0.99 * lt_err + 0.01 * mean(err).
            lt_err = lt_err * 0.99 + (err[0] + err[1]) * 0.005
            llt_loss = llt_loss * 0.9 + 0.1 * lt_loss
            print({
                'epoch': epoch,
                'loss': step_loss,
                'lltLoss': llt_loss,
                'ok0': correct[0],
                'ok1': correct[1],
                'succ': correct[0] and correct[1],
                'acc': f"{int(100 - (err[0] + err[1]) * 50)}%",
            })

        torch.save(model.state_dict(), f"{SAVE_DIR}/{tid}_{epoch}.n")
        # A constant 0.5 output scores an error of 0.5 on both streams, so an
        # averaged error still near 0.5 means nothing useful was learned.
        if 0.45 < lt_err < 0.55:
            print("[~] My emperor! I've failed! A BARREL ROLL!")
        else:
            print("[~] Booyaaa!!!!")


if __name__ == "__main__":
    model = Model()
    train(model)