import math
import os
import random

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

import shark
from model import Model


def _byte_to_input(value):
    """Encode one sample element as a (1, 1, n_bits) float tensor of its binary digits.

    The encoding is MSB-first and zero-padded to at least 8 digits (values above 255
    yield more than 8 digits, exactly like the original inline expression).
    """
    bits = [float(d) for d in bin(value)[2:].zfill(8)]
    return torch.tensor([[bits]], dtype=torch.float32)


def train(model, seq_len=16*128, epochs=1024, lr=0.001,
          save_every=8, save_dir='model_savepoints'):
    """Train ``model`` to tell samples of label 0 from samples of label 1.

    Each epoch draws one sample per label via ``shark.getSample(seq_len, label)``.
    It then feeds both samples element by element through ``model``, which keeps a
    separate recurrent (h, c) state per label. It takes one optimizer step per
    element and label, using binary cross-entropy against the label.

    Args:
        model: Module called as ``model(x, (h, c)) -> (y_pred, (h, c))`` and providing
            ``init_state(seq_len) -> (h, c)``. ``y_pred`` must contain a single
            probability (it is read via ``y_pred[0][0][0]`` and ``y_pred.item()``).
        seq_len: Passed to both ``shark.getSample`` and ``model.init_state``.
            NOTE(review): the historical comment said "0.25KiB", which matches 2048
            *bits*, not 2048 bytes -- confirm the unit ``shark.getSample`` uses.
        epochs: Number of epochs (default 1024, as before).
        lr: Adam learning rate (default 0.001, as before).
        save_every: Save a checkpoint when ``epoch % save_every == 0``; 0 or None
            disables checkpoints.
        save_dir: Directory for checkpoints; created if missing.
    """
    # Random 5-digit run id so concurrent runs don't overwrite each other's checkpoints.
    tid = str(int(random.random()*99999)).zfill(5)
    print("[i] I am "+str(tid))

    # ltLoss: exponential moving average (factor 0.9) of the per-step loss;
    # lltLoss: a second EMA applied on top of ltLoss.
    ltLoss = 0.75
    lltLoss = 0.80

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if save_every:
        os.makedirs(save_dir, exist_ok=True)

    # Index t in these pairs is the label (0 or 1) of the stream being processed.
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]

    for epoch in range(epochs):
        state_h[0], state_c[0] = model.init_state(seq_len)
        state_h[1], state_c[1] = model.init_state(seq_len)
        # presumably each blob is a sequence of non-negative ints (bytes) -- confirm in shark
        blob[0], _ = shark.getSample(seq_len, 0)
        blob[1], _ = shark.getSample(seq_len, 1)

        # Guard against samples of unequal length (the original indexed blob[1]
        # with blob[0]'s length and could raise IndexError).
        n_steps = min(len(blob[0]), len(blob[1]))
        if n_steps == 0:
            print("[!] Empty sample in epoch " + str(epoch) + ", skipping")
            continue

        for i in range(n_steps):
            for t in range(2):
                x = _byte_to_input(blob[t][i])
                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))
                # Truncate backpropagation to the current step; the next step starts
                # from a state with no graph history.
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()
                # Clear gradients before each backward: previously they were only
                # cleared once per epoch, so every step applied the running sum of
                # all earlier gradients in the epoch.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                correct[t] = round(y_pred.item()) == t
                ltLoss = ltLoss*0.9 + 0.1*loss.item()
                lltLoss = lltLoss*0.9 + 0.1*ltLoss

        print({
            'epoch': epoch,
            'loss': loss.item(),  # last step's loss (label 1)
            'lltLoss': lltLoss,
            'ok0': correct[0],
            'ok1': correct[1],
            'succ': correct[0] and correct[1],
            # Heuristic display value derived from the smoothed loss only,
            # not a measured accuracy.
            'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%"
        })

        if save_every and epoch % save_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_dir, tid+'_'+str(epoch)+'.n'))

    # Final verdict based on fixed thresholds on the doubly-smoothed loss.
    if lltLoss > 0.49:
        print("[~] My emperor! I've failed! A BARREL ROLL!")
    elif lltLoss < 0.45:
        print("[~] Booyaaa!!!!")
    else:
        print("[~] Meh...")


model = Model()
train(model)