shark/discriminate.py

79 lines
2.0 KiB
Python

import os
import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
# Size multiplier for the random plaintexts built in test_scoring (length * bs bytes).
# 256 // 8 == 32; presumably a 256-bit cipher block expressed in bytes -- TODO confirm against shark.
bs = 256 // 8
class Model(nn.Module):
    """LSTM discriminator producing a score in (0, 1) for every input position.

    Architecture: a stacked ``nn.LSTM`` (sequence-first layout, since
    ``batch_first`` is left at its default) followed by a ``Linear`` layer down
    to a single value and a ``Sigmoid``.

    The constructor defaults reproduce the previously hard-coded configuration
    (8 -> 16 hidden, 3 layers, dropout 0.1), so existing ``state_dict``
    checkpoints keep matching the parameter names and shapes.
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM over *x* starting from *prev_state*.

        Args:
            x: float tensor of shape ``(seq_len, batch, input_size)``.
            prev_state: ``(h, c)`` tuple, each ``(num_layers, batch, hidden_size)``,
                e.g. from ``init_state``.

        Returns:
            ``(scores, state)``: ``scores`` has shape ``(seq_len, batch, 1)`` with
            values in (0, 1); ``state`` is the LSTM's final ``(h, c)``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        return self.out(logits), state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zeroed ``(h0, c0)`` pair for the LSTM.

        Shapes, dtype and device are taken from the model itself, so the state
        stays valid if the layer sizes change or the model is moved/cast.

        Args:
            sequence_length: accepted for API compatibility; the initial state
                does not depend on it.
            batch_size: batch dimension of the state (default 1).

        Returns:
            Tuple of two zero tensors of shape ``(num_layers, batch_size, hidden_size)``.
        """
        weight = self.fc.weight
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (
            torch.zeros(shape, dtype=weight.dtype, device=weight.device),
            torch.zeros(shape, dtype=weight.dtype, device=weight.device),
        )
def run(model, seq):
    """Score *seq* with *model* and return the output at the final position.

    Each element of *seq* is expanded into 8 float bits, most significant bit
    first (the same order as ``bin(v)[2:].zfill(8)``), giving an input of shape
    ``(len(seq), 1, 8)`` -- sequence-first, batch of one. It is fed to *model*
    starting from ``model.init_state(len(seq))``.

    Args:
        model: discriminator exposing ``init_state(sequence_length)`` and
            ``forward(x, prev_state) -> (scores, state)``, such as ``Model``.
        seq: non-empty sequence of ints in ``[0, 255]`` (e.g. ``bytes``).

    Returns:
        float: the model's score at the last position of *seq*.

    Raises:
        ValueError: if *seq* is empty or contains a value outside ``[0, 255]``.
    """
    if len(seq) == 0:
        # No position would be processed, so there is no score to return.
        raise ValueError("run() needs a non-empty sequence")
    bits = []
    for value in seq:
        if not 0 <= value <= 255:
            # Anything else does not fit the model's 8-wide input.
            raise ValueError(f"value {value!r} is not a byte (0-255)")
        bits.append([float((value >> shift) & 1) for shift in range(7, -1, -1)])
    x = torch.tensor(bits, dtype=torch.float32).unsqueeze(1)
    state = model.init_state(len(seq))
    # Inference only: the result is returned as a plain float, so no autograd
    # graph is needed. For an LSTM-based model, one call over the whole sequence
    # is equivalent to stepping one position at a time while carrying the state.
    with torch.no_grad():
        scores, _ = model(x, state)
    return scores[-1].item()
def score(model, ciphertext, hypothesis):
    """Return the discriminator score for *hypothesis* as a candidate plaintext.

    *ciphertext* and *hypothesis* are combined with ``shark.xor`` (presumably a
    byte-wise XOR -- confirm in shark), and the combined sequence is scored by
    ``run``. The return value is whatever ``run`` returns.
    """
    return run(model, shark.xor(ciphertext, hypothesis))
def test_scoring(model, length=16):
    """Print discriminator scores for a correct, a wrong and a half-correct guess.

    A random plaintext of ``length * bs`` bytes is encrypted with a fixed key
    and a fresh IV via ``shark.enc``. ``score`` is then evaluated against:

    * ``'h'``: the true plaintext,
    * ``'l'``: an unrelated random plaintext of the same length,
    * ``'m'``: a byte-wise mix taking each byte from either with probability ~0.5.

    Presumably a well-trained discriminator yields h > m > l (TODO confirm
    against the training code).

    Args:
        model: discriminator passed through to ``score``.
        length: number of ``bs``-byte units of plaintext (default 16).
    """
    iv = shark.genIV()
    # TODO: Generate human language
    plaintext = os.urandom(length * bs)
    plaintext_alt = os.urandom(length * bs)
    # Mix across every byte so the hypothesis is as long as the other two.
    # Build it as bytes so all three hypotheses have the same type.
    plaintext_semi = bytes(
        true_byte if random.random() > 0.5 else alt_byte
        for true_byte, alt_byte in zip(plaintext, plaintext_alt)
    )
    ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)
    high = score(model, ciphertext, plaintext)
    low = score(model, ciphertext, plaintext_alt)
    mid = score(model, ciphertext, plaintext_semi)
    print({'h': high, 'l': low, 'm': mid})
def load(path='wh_discriminator.n'):
    """Build a ``Model``, restore its weights from *path* and switch it to eval mode.

    Args:
        path: checkpoint file containing a ``state_dict`` for ``Model``. The
            default is ``'wh_discriminator.n'``, relative to the current
            working directory.

    Returns:
        Model: the restored model in evaluation mode (dropout disabled).
    """
    model = Model()
    # Model() is created on the CPU, so map the checkpoint's tensors there as well.
    # This also lets checkpoints saved from a CUDA device load on GPU-less hosts.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    return model
if __name__ == "__main__":
    # Script entry point: restore the saved discriminator and print a quick score check.
    test_scoring(load())