# shark/discriminator.py
#
# Source-viewer metadata (not code): 69 lines, 2.0 KiB, Python.

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
class Model(nn.Module):
    """LSTM discriminator producing a probability for every time step.

    The LSTM is built without ``batch_first``, so inputs are laid out as
    ``(seq_len, batch, input_size)``. ``fc`` followed by a sigmoid maps each
    time step's hidden output to a single value in [0, 1].

    The constructor arguments default to the original fixed architecture
    (8 inputs, 16 hidden units, 3 layers, 0.1 inter-layer dropout). With those
    defaults, parameter names and shapes are unchanged, so existing
    checkpoints still load.
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run ``x`` through the network starting from ``prev_state``.

        Args:
            x: float tensor of shape ``(seq_len, batch, input_size)``.
            prev_state: ``(h, c)`` pair, e.g. from :meth:`init_state`.

        Returns:
            ``(probs, state)``: the sigmoid outputs, shape ``(seq_len, batch, 1)``,
            and the LSTM's final ``(h, c)``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zeroed ``(h0, c0)`` pair for starting a new sequence.

        The sizes are read from the LSTM itself, not hard-coded, so this stays
        consistent with the constructor arguments. ``sequence_length`` does not
        affect the result; it is accepted only so existing callers keep working.

        Returns:
            Two zero tensors of shape ``(num_layers, batch_size, hidden_size)``.
        """
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return torch.zeros(shape), torch.zeros(shape)
def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5)
print("[i] I am "+str(tid))
ltLoss = 50
lltLoss = 51
model.train()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(1024):
state_h, state_c = model.init_state(seq_len)
blob, y = shark.getSample(min(seq_len, 16*(epoch+1)), epoch%2)
optimizer.zero_grad()
for i in range(len(blob)):
x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32)
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32))
state_h = state_h.detach()
state_c = state_c.detach()
loss.backward()
optimizer.step()
correct = round(y_pred.item()) == y
ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'correct?': correct })
if epoch % 8 == 0:
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
if __name__ == "__main__":
    # Train only when run as a script. Importing this module (e.g. to reuse
    # Model) must not start a full training run as an import side effect.
    model = Model()
    train(model)