# shark/discriminator.py

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
class Model(nn.Module):
    """LSTM discriminator that maps a sequence of feature vectors to a score in (0, 1).

    A stacked LSTM is followed by a linear layer with a single output and a
    sigmoid. The defaults reproduce the original fixed architecture
    (8 inputs, 16 hidden units, 3 layers, dropout 0.1).

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers.
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the network on ``x`` starting from ``prev_state``.

        Args:
            x: Input tensor of shape ``(seq_len, batch, input_size)``; the
                LSTM is built without ``batch_first``.
            prev_state: ``(h, c)`` tuple as returned by ``init_state`` or by
                a previous call to ``forward``.

        Returns:
            ``(val, state)`` where ``val`` holds the sigmoid output for every
            time step (last dimension of size 1) and ``state`` is the
            updated ``(h, c)`` tuple.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` state for a batch of size 1.

        Args:
            sequence_length: Unused; kept so existing callers keep working.

        Returns:
            Two zero tensors of shape ``(num_layers, 1, hidden_size)``.
        """
        # Derive the shape from the LSTM itself so it cannot drift from the
        # constructor arguments, and match the model's device and dtype so
        # the state stays usable after the model has been moved or cast.
        shape = (self.lstm.num_layers, 1, self.lstm.hidden_size)
        reference = self.fc.weight
        return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
                torch.zeros(shape, dtype=reference.dtype, device=reference.device))
def train(model, seq_len=16*64):
    """Train ``model`` for 1024 iterations on samples from ``shark.getSample``.

    Each iteration fetches one sample, feeds it to the model one element at
    a time (each element expanded to 8 bit-features), and applies a single
    Adam step on the binary cross-entropy between the final prediction and
    the sample's label. A checkpoint of the model's state dict is written
    to ``model_savepoints/`` every 8th iteration.

    Args:
        model: Module exposing ``init_state(seq_len)`` and a forward pass
            ``model(x, (h, c)) -> (prediction, (h, c))``.
        seq_len: Upper bound on the requested sample length; the requested
            length starts at 16 and grows by 16 per iteration up to this cap.

    Raises:
        ValueError: If ``shark.getSample`` returns an empty blob.
    """
    # Local import keeps this block self-contained without touching the
    # file's import header.
    import os

    tid = str(int(random.random()*99999)).zfill(5)
    print("[i] I am "+str(tid))
    ltLoss = 50  # exponentially smoothed loss, reported in the log line
    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # torch.save does not create missing directories, so make sure the
    # checkpoint directory exists before the first save.
    os.makedirs('model_savepoints', exist_ok=True)

    for epoch in range(1024):
        state_h, state_c = model.init_state(seq_len)
        # presumably the second argument (alternating 0/1) selects the
        # sample's class — verify against shark.getSample
        blob, y = shark.getSample(min(seq_len, 16*(epoch+1)), epoch%2)
        if len(blob) == 0:
            # Without any input there is no prediction to score.
            raise ValueError("shark.getSample returned an empty blob")

        optimizer.zero_grad()
        for value in blob:
            # Expand the integer into its zero-padded 8-digit binary form,
            # one float feature per bit, shaped (seq=1, batch=1, features=8).
            bits = [float(d) for d in bin(value)[2:].zfill(8)]
            x = torch.tensor([[bits]], dtype=torch.float32)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        # NOTE(review): this file was recovered without indentation; the
        # loss/step below is placed after the element loop (one update per
        # sample, scored on the last prediction) — confirm.
        loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32))
        state_h = state_h.detach()
        state_c = state_c.detach()
        loss.backward()
        optimizer.step()

        correct = round(y_pred.item()) == y
        ltLoss = ltLoss*0.9 + 0.1*loss.item()
        print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'correct?': correct })
        if epoch % 8 == 0:
            torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
# Script entry: build a discriminator with the default architecture and start
# training immediately. These statements run at import time as well as when
# the file is executed directly.
model = Model()
train(model)