# shark/discriminator.py

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
class Model(nn.Module):
    """LSTM discriminator that outputs a probability for every time step.

    The LSTM is built with PyTorch's default ``batch_first=False``, so inputs
    are laid out as ``(seq_len, batch, input_size)``. Each step's LSTM output
    is mapped to a single logit by ``fc`` and squashed into ``[0, 1]`` by a
    sigmoid, which is what the ``nn.BCELoss`` used for training expects.

    The defaults reproduce the original architecture (8 inputs, 3 stacked
    layers of 16 units, dropout 0.1). The submodule names ``lstm``/``fc``/``out``
    are kept so previously saved ``state_dict`` checkpoints still load.

    Args:
        input_size: Features per time step.
        hidden_size: Width of the LSTM hidden and cell states.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to ``nn.LSTM`` (applied between
            stacked layers).
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run ``x`` through the network.

        Args:
            x: Tensor of shape ``(seq_len, batch, input_size)``.
            prev_state: ``(h, c)`` tuple, each of shape
                ``(num_layers, batch, hidden_size)``, e.g. from ``init_state``.

        Returns:
            ``(probs, (h, c))``, where ``probs`` has shape
            ``(seq_len, batch, 1)`` and ``(h, c)`` is the LSTM state after
            the last step.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        return self.out(logits), state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zeroed ``(h, c)`` LSTM state for ``batch_size`` sequences.

        The state's size does not depend on the sequence length;
        ``sequence_length`` is accepted only so existing callers keep working.
        The shape is read from the LSTM's own configuration instead of being
        hard-coded, and the tensors share the parameters' device and dtype.
        """
        reference = next(self.parameters())
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)
def _bytes_to_bits(blob):
    """Expand a byte sequence into a float32 tensor of shape ``(len(blob), 1, 8)``.

    Bits are ordered most-significant first, matching ``bin(b)[2:].zfill(8)``.

    Raises:
        ValueError: if ``blob`` is empty or holds a value outside ``0..255``.
    """
    values = torch.as_tensor(list(blob), dtype=torch.int64)
    if values.numel() == 0:
        raise ValueError("sample is empty")
    if bool(((values < 0) | (values > 255)).any()):
        raise ValueError("sample values must be bytes in the range 0..255")
    shifts = torch.arange(7, -1, -1)  # 7, 6, ..., 0 -> MSB first
    bits = (values.unsqueeze(1) >> shifts) & 1
    # Insert the batch dimension (a batch of one) expected by the LSTM.
    return bits.to(torch.float32).unsqueeze(1)


def train(model, seq_len=16*64, *, epochs=1024, lr=0.01, sample_fn=None,
          save_dir='model_savepoints'):
    """Train ``model`` to label byte sequences and checkpoint good states.

    Each epoch draws one sample via ``sample_fn(seq_len, epoch % 2)`` (default
    ``shark.getSample``), which must return ``(blob, label)``. The requested
    label therefore alternates 0/1 between epochs. Every byte is expanded into
    its 8 bits, and the whole sequence is fed through the model in a single
    call. The loss is the mean binary cross-entropy between each step's
    prediction and ``label``.

    Two exponential moving averages of the loss are tracked. At every epoch
    where both are below their thresholds, ``model.state_dict()`` is saved as
    ``<save_dir>/<run id>_<epoch>.n``.

    Args:
        model: Module called as ``model(x, state)`` that returns
            ``(probabilities, state)`` and provides ``init_state(seq_len)``.
        seq_len: Number of bytes requested per sample.
        epochs: Number of training epochs (one sample each).
        lr: Adam learning rate.
        sample_fn: Sample source; ``None`` means ``shark.getSample``.
        save_dir: Directory for checkpoints, created on first save.

    Raises:
        ValueError: if a sample is empty or contains a value outside 0..255.
    """
    import os

    if sample_fn is None:
        sample_fn = shark.getSample
    # Random 5-digit run id so checkpoints from different runs don't collide.
    run_id = str(int(random.random() * 99999)).zfill(5)
    lt_loss = 100.0   # EMA of the per-epoch loss
    llt_loss = 100.0  # EMA of lt_loss (slower to react)
    model.train()
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        prev_state = model.init_state(seq_len)
        blob, y = sample_fn(seq_len, epoch % 2)
        x = _bytes_to_bits(blob)
        # A single call over the whole sequence lets gradients flow through
        # every step (no per-step truncation) and avoids a Python loop per byte.
        y_pred, _ = model(x, prev_state)
        # Every step is trained to predict the sample's label.
        target = torch.full_like(y_pred, float(y))
        loss = criterion(y_pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lt_loss = lt_loss * 0.9 + 0.1 * loss.item()
        llt_loss = llt_loss * 0.9 + 0.1 * lt_loss
        print({'epoch': epoch, 'loss': loss.item(), 'ltLoss': lt_loss})
        if lt_loss < 0.40 and llt_loss < 0.475:
            print("[*] Hell Yeah! Poccing!")
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(save_dir, f"{run_id}_{epoch}.n"))
if __name__ == "__main__":
    # Train only when executed as a script, so importing this module (for
    # example to reuse Model) does not start a training run as a side effect.
    model = Model()
    train(model)