2021-09-21 09:14:31 +02:00
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn, optim
|
|
|
|
from torch.utils.data import DataLoader
|
2021-09-21 09:49:27 +02:00
|
|
|
import numpy as np
|
|
|
|
import random
|
2021-09-22 10:28:08 +02:00
|
|
|
import math
|
2021-09-21 09:14:31 +02:00
|
|
|
|
|
|
|
import shark
|
2021-09-22 10:37:11 +02:00
|
|
|
from model import Model
|
2021-09-21 09:14:31 +02:00
|
|
|
|
|
|
|
def train(model, seq_len=16*64, *, epochs=1024, lr=0.0001,
          save_dir='model_savepoints', save_every=8):
    """Train ``model`` as a binary classifier over two byte-sample classes.

    Every epoch, one sample is drawn for each class index (0 and 1) via
    ``shark.getSample(length, class_index)``. Each sample is fed to ``model``
    one element at a time, with a separate recurrent state per class. The
    model is optimised with binary cross-entropy so that its output matches
    the class index (0.0 or 1.0).

    Args:
        model: Recurrent network. It is called as ``model(x, (h, c))`` and
            must return ``(y_pred, (h, c))``. It must also provide
            ``init_state(seq_len)``, which returns an ``(h, c)`` pair.
            ``y_pred[0][0][0]`` is fed to ``nn.BCELoss``, so it is presumably
            a probability in [0, 1] -- TODO confirm against ``Model``.
        seq_len: Upper bound on the sample length requested from
            ``shark.getSample``; also passed to ``model.init_state``.
            Early epochs request shorter samples (``16 * (epoch + 1)``).
        epochs: Number of training epochs.
        lr: Adam learning rate.
        save_dir: Directory for ``state_dict`` checkpoints; created if missing.
        save_every: A checkpoint is written whenever ``epoch % save_every == 0``.
    """
    import os

    # Random 5-digit run id, used as the prefix of checkpoint file names.
    tid = str(int(random.random() * 99999)).zfill(5)
    print("[i] I am " + tid)

    # Exponential moving averages of the loss (and of that average). They start
    # high so the reported 'acc' = max(0, 1 - sqrt(lltLoss)) begins at 0%.
    ltLoss = 50
    lltLoss = 52

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    os.makedirs(save_dir, exist_ok=True)

    # Slot t (0 or 1) holds the recurrent state, current sample and
    # last-step verdict for class t.
    state_h = [None, None]
    state_c = [None, None]
    blob = [None, None]
    correct = [None, None]

    for epoch in range(epochs):
        state_h[0], state_c[0] = model.init_state(seq_len)
        state_h[1], state_c[1] = model.init_state(seq_len)

        # Curriculum: sample length grows by 16 per epoch up to seq_len.
        sample_len = min(seq_len, 16 * (epoch + 1))
        blob[0], _ = shark.getSample(sample_len, 0)
        blob[1], _ = shark.getSample(sample_len, 1)

        last_loss = float('nan')
        # Iterate only over positions present in both samples.
        for i in range(min(len(blob[0]), len(blob[1]))):
            for t in range(2):
                # Element -> list of '0'/'1' bit floats, shaped (1, 1, n_bits).
                # Assumes elements are ints in 0..255 so there are exactly
                # 8 bits -- NOTE(review): confirm getSample's element type.
                x = torch.tensor(
                    [[[float(d) for d in bin(blob[t][i])[2:].zfill(8)]]],
                    dtype=torch.float32,
                )
                y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
                loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))

                # Truncated backprop: cut the graph so later steps do not
                # backpropagate into earlier elements.
                state_h[t] = state_h[t].detach()
                state_c[t] = state_c[t].detach()

                # Clear the previous step's gradients. Otherwise each step
                # would apply the accumulated sum of every earlier gradient
                # in the epoch.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                last_loss = loss.item()
                correct[t] = round(y_pred.item()) == t
                ltLoss = ltLoss * 0.9 + 0.1 * last_loss
                lltLoss = lltLoss * 0.9 + 0.1 * ltLoss

        print({'epoch': epoch, 'loss': last_loss, 'ltLoss': ltLoss,
               'ok0': correct[0], 'ok1': correct[1],
               'succ': correct[0] and correct[1],
               'acc': str(int(max(0, 1 - math.sqrt(lltLoss)) * 100)) + "%"})

        if epoch % save_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_dir, tid + '_' + str(epoch) + '.n'))
|
2021-09-21 09:14:31 +02:00
|
|
|
|
|
|
|
# Script entry point: build a fresh model and train it with default settings.
# Guarded so that importing this module does not start a training run.
if __name__ == "__main__":
    model = Model()
    train(model)
|