# shark/discriminator.py
# (58 lines, 1.5 KiB, Python; captured from a web "Raw / Normal View / History" page)

import torch
from torch import nn
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
import shark
class Model(nn.Module):
    """Stacked LSTM followed by a linear layer and a sigmoid.

    Produces one value in (0, 1) for every time step of the input. The LSTM
    uses PyTorch's default ``batch_first=False``, so inputs are laid out as
    ``(seq_len, batch, input_size)``.

    Args:
        input_size: features per time step (8 = one bit per byte bit).
        hidden_size: LSTM hidden/cell width.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (ignored by PyTorch when
            ``num_layers == 1``).
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.05):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM over ``x`` and return ``(probabilities, new_state)``.

        ``prev_state`` is the ``(h, c)`` tuple expected by ``nn.LSTM``; the
        returned probabilities have shape ``(seq_len, batch, 1)``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)  # pre-sigmoid scores
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zeroed ``(h, c)`` state tuple for a fresh sequence.

        ``sequence_length`` is not needed to size the state and is kept only
        for backward compatibility with existing callers. The shape is derived
        from the LSTM's own configuration so it can never drift from
        ``__init__``, and the tensors match the parameters' dtype and device.
        """
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        reference = next(self.parameters())
        return (reference.new_zeros(shape), reference.new_zeros(shape))
def _byte_to_bits(value):
    """Return the binary expansion of ``value`` as floats, MSB first, zero-padded to 8.

    Same digits as ``bin(value)[2:].zfill(8)`` for non-negative ints; values
    above 255 yield more than 8 entries and would not fit an 8-input model.
    """
    return [float(bit) for bit in format(int(value), "08b")]


def train(model, seq_len=16 * 64, *, epochs=1024, lr=0.001, sample_fn=None):
    """Train ``model`` byte-by-byte with one optimizer step per byte.

    Each epoch requests one sample of ``seq_len`` bytes for label class
    ``epoch % 2``. It then feeds the bytes through the model one at a time as
    8-element bit vectors. The recurrent state is carried forward but detached
    after every step, so each update back-propagates through a single step
    only. One summary dict is printed per epoch.

    Args:
        model: module exposing ``init_state(seq_len)`` and
            ``forward(x, (h, c)) -> (probability, (h, c))``.
        seq_len: number of bytes requested per sample.
        epochs: number of samples (and printed summaries) to train on.
        lr: Adam learning rate.
        sample_fn: callable ``(length, label) -> (blob, y)``; defaults to
            ``shark.getSample``. ``blob`` must be a sized iterable of ints.
    """
    if sample_fn is None:
        sample_fn = shark.getSample
    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        state_h, state_c = model.init_state(seq_len)
        blob, y = sample_fn(seq_len, epoch % 2)
        # len() rather than truthiness: blob may be a numpy array.
        if len(blob) == 0:
            # No steps means no loss/prediction to report; skip instead of
            # crashing on unbound names below.
            continue
        target = torch.tensor(y, dtype=torch.float32)  # same for every byte
        for byte in blob:
            x = torch.tensor([[_byte_to_bits(byte)]], dtype=torch.float32)
            # Clear gradients every step; otherwise each step() would apply
            # the accumulated gradients of all earlier bytes in this epoch.
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred[0][0][0], target)
            loss.backward()
            optimizer.step()
            # Cut the graph so the next step's backward() does not reach
            # into this (already freed) step.
            state_h = state_h.detach()
            state_c = state_c.detach()
        print({'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0]) - y})
# Script entry point: build a fresh discriminator and train it with defaults.
# NOTE(review): this executes on import as well; consider an
# `if __name__ == "__main__":` guard if the module is ever imported elsewhere.
model = Model()
train(model=model)