2021-09-21 09:14:31 +02:00
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn, optim
|
|
|
|
from torch.utils.data import DataLoader
|
2021-09-21 09:49:27 +02:00
|
|
|
import numpy as np
|
|
|
|
import random
|
2021-09-21 09:14:31 +02:00
|
|
|
|
|
|
|
import shark
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Stacked LSTM that emits a per-step probability through a sigmoid head.

    The defaults reproduce the original fixed architecture
    (8 inputs -> 3 stacked 16-unit LSTM layers -> 1 output), and submodule
    names (``lstm``, ``fc``, ``out``) are unchanged, so existing state dicts
    still load.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden and cell states.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied between stacked LSTM layers.
    """

    def __init__(self, input_size: int = 8, hidden_size: int = 16,
                 num_layers: int = 3, dropout: float = 0.1):
        super().__init__()
        # Remembered so init_state can build states matching the LSTM config
        # instead of repeating the sizes as literals.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM from ``prev_state`` and map each step to a probability.

        Args:
            x: input tensor; with nn.LSTM's default ``batch_first=False`` it is
                laid out as (seq_len, batch, input_size).
            prev_state: ``(h, c)`` tuple, e.g. as returned by ``init_state``.

        Returns:
            ``(probs, state)`` where ``probs`` has shape (seq_len, batch, 1)
            with values in [0, 1], and ``state`` is the new ``(h, c)`` tuple.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        return self.out(logits), state

    def init_state(self, sequence_length, batch_size: int = 1):
        """Return zeroed ``(h, c)`` LSTM states.

        Args:
            sequence_length: accepted for API compatibility; unused, because
                the LSTM state shape does not depend on sequence length.
            batch_size: batch dimension of the states (default 1).

        Returns:
            Two zero tensors of shape (num_layers, batch_size, hidden_size),
            on the same device and dtype as the model's parameters.
        """
        param = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_size)
        return (torch.zeros(shape, dtype=param.dtype, device=param.device),
                torch.zeros(shape, dtype=param.dtype, device=param.device))
|
|
|
|
|
|
|
|
def _byte_to_bits(value):
    """Encode one byte as a (1, 1, 8) float32 tensor of its bits, MSB first."""
    # Same encoding as before: values above 255 yield more than 8 bits.
    return torch.tensor([[[float(d) for d in bin(value)[2:].zfill(8)]]],
                        dtype=torch.float32)


def train(model, seq_len=16*64, epochs=1024, lr=0.01,
          save_dir='model_savepoints'):
    """Train ``model`` on ``shark.getSample`` data and checkpoint when good.

    Each epoch requests one sample of ``seq_len`` items, labelled
    ``epoch % 2``, so labels alternate between 0 and 1. The sample is fed
    through the model one element at a time as an 8-bit vector. After every
    element, the sigmoid output is scored with BCE against the sample's
    label and an Adam step is taken. The recurrent state is detached after
    each step, so gradients never flow back past the current element.

    Two exponential moving averages of the loss are tracked. Once both fall
    below their thresholds, the state dict is saved every epoch to
    ``<save_dir>/<run-id>_<epoch>.n``.

    Args:
        model: module exposing ``init_state(seq_len)`` and
            ``forward(x, state) -> (probs, state)``.
        seq_len: sample length requested from ``shark.getSample``.
        epochs: number of samples (epochs) to train on.
        lr: Adam learning rate.
        save_dir: checkpoint directory; created if it does not exist.
    """
    from pathlib import Path

    run_id = str(int(random.random() * 99999)).zfill(5)
    lt_loss = 100   # EMA of the per-epoch final loss
    llt_loss = 100  # EMA of lt_loss (slower-moving)
    save_path = Path(save_dir)

    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        state_h, state_c = model.init_state(seq_len)
        blob, y = shark.getSample(seq_len, epoch % 2)
        if len(blob) == 0:
            # Without this, `loss` below would be unbound (first epoch) or
            # stale from a previous epoch.
            continue
        # The label is constant across the sample, so build it once.
        target = torch.tensor(y, dtype=torch.float32)

        for byte in blob:
            x = _byte_to_bits(byte)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred[0][0][0], target)

            # Cut the graph so the next step does not backprop into this one.
            state_h = state_h.detach()
            state_c = state_c.detach()

            # Zero every step: we step after every element, so clearing only
            # once per epoch would make each update apply the running sum of
            # all earlier gradients in the epoch.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        lt_loss = lt_loss * 0.9 + 0.1 * loss.item()
        llt_loss = llt_loss * 0.9 + 0.1 * lt_loss
        print({'epoch': epoch, 'loss': loss.item(), 'ltLoss': lt_loss})
        if lt_loss < 0.40 and llt_loss < 0.475:
            print("[*] Hell Yeah! Poccing!")
            save_path.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path / f"{run_id}_{epoch}.n")
|
2021-09-21 09:14:31 +02:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Train only when run as a script; importing this module (e.g. to reuse
    # Model) should not start a long training run or write checkpoints.
    model = Model()
    train(model)
|