79 lines
2.0 KiB
Python
79 lines
2.0 KiB
Python
|
import os
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch import nn, optim
|
||
|
from torch.utils.data import DataLoader
|
||
|
import numpy as np
|
||
|
import random
|
||
|
|
||
|
import shark
|
||
|
|
||
|
# 256 bits expressed in bytes (= 32). test_scoring uses it as a per-chunk byte
# count; presumably the cipher's block size — TODO confirm against shark.
bs = 256 // 8
|
||
|
|
||
|
class Model(nn.Module):
    """Stacked-LSTM discriminator that emits a sigmoid score per time step.

    An ``nn.LSTM`` feeds a single-output linear head, whose logit is squashed
    by a sigmoid so every step yields a value in (0, 1). The defaults
    reproduce the original hard-coded architecture. The submodule attribute
    names (``lstm``, ``fc``) are the ``state_dict`` keys used when restoring
    saved weights, so they must not change.

    Args:
        input_size: Features per time step.
        hidden_size: LSTM hidden/cell width, also the linear head's input width.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to ``nn.LSTM`` (applied between
            stacked layers in training mode).
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM over ``x`` starting from ``prev_state``.

        Args:
            x: Tensor of shape ``(seq_len, batch, input_size)``. The LSTM is
                built without ``batch_first``, so time is the leading axis.
            prev_state: ``(h, c)`` pair, each ``(num_layers, batch, hidden_size)``.

        Returns:
            ``(val, state)``. ``val`` is the sigmoid score for every time step,
            with shape ``(seq_len, batch, 1)``. ``state`` is the updated
            ``(h, c)`` pair.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        return self.out(logits), state

    def init_state(self, sequence_length):
        """Return a zeroed ``(h, c)`` state for a batch of one sequence.

        Args:
            sequence_length: Unused; kept so existing callers keep working.

        Returns:
            Two zero tensors of shape ``(num_layers, 1, hidden_size)``. They
            live on the same device and dtype as the model's parameters, and
            the shape is read from the LSTM rather than hard-coded, so it
            tracks the constructor arguments.
        """
        shape = (self.lstm.num_layers, 1, self.lstm.hidden_size)
        reference = self.fc.weight
        return reference.new_zeros(shape), reference.new_zeros(shape)
|
||
|
|
||
|
def run(model, seq):
    """Feed ``seq`` through ``model`` one element at a time; return the last score.

    Each element is expanded to its 8 bits, most significant first, and fed
    as one time step of shape ``(1, 1, 8)``. The recurrent state is carried
    between steps. Everything runs under ``torch.no_grad()``, so no autograd
    graph accumulates across the sequence; the result is reduced to a Python
    float anyway.

    Args:
        model: Object exposing ``init_state(n)`` and ``model(x, state)`` that
            returns ``(score_tensor, new_state)``, e.g. ``Model``.
        seq: Sized sequence of integers in ``[0, 255]``, e.g. ``bytes``.

    Returns:
        The model's output after the final element, as a float.

    Raises:
        ValueError: If ``seq`` is empty or an element is outside ``[0, 255]``.
    """
    if len(seq) == 0:
        # Without this guard the loop never assigns y_pred.
        raise ValueError("run() needs a non-empty sequence")

    state = model.init_state(len(seq))
    with torch.no_grad():
        for byte in seq:
            if not 0 <= byte <= 255:
                # A wider value would give >8 bits and a confusing LSTM shape error.
                raise ValueError(f"sequence element {byte!r} is not a byte value")
            bits = [float(d) for d in bin(byte)[2:].zfill(8)]
            x = torch.tensor([[bits]], dtype=torch.float32)
            y_pred, state = model(x, state)

    return y_pred.item()
|
||
|
|
||
|
def score(model, ciphertext, hypothesis):
    """Return the model's score for ``hypothesis`` as a plaintext of ``ciphertext``.

    Combines the two with ``shark.xor`` and evaluates the result with ``run``.
    NOTE(review): presumably a bytewise XOR, so a correct hypothesis exposes
    the keystream the discriminator was trained on — confirm against ``shark``.
    """
    return run(model, shark.xor(ciphertext, hypothesis))
|
||
|
|
||
|
def test_scoring(model, length=16):
    """Print the model's scores for correct, unrelated, and half-correct plaintexts.

    Encrypts a random plaintext under a fixed key and a fresh IV, then scores
    three hypotheses against the ciphertext:

    * the true plaintext (``'h'``);
    * an unrelated random plaintext (``'l'``);
    * a byte-wise random mix of the two (``'m'``).

    The high/low/mid naming indicates the expected ordering h > m > l.

    Args:
        model: Model passed through to ``score``.
        length: Number of ``bs``-byte chunks of plaintext to generate.
    """
    iv = shark.genIV()

    # TODO: Generate human language
    plaintext = os.urandom(length*bs)
    plaintextAlt = os.urandom(length*bs)

    # Mix over every byte, not just the first ``length``, and keep ``bytes``.
    # This gives the mixed hypothesis the same length and type as the others.
    plaintextSemi = bytes(
        p if random.random() > 0.5 else a
        for p, a in zip(plaintext, plaintextAlt)
    )

    ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)

    high = score(model, ciphertext, plaintext)
    low = score(model, ciphertext, plaintextAlt)
    mid = score(model, ciphertext, plaintextSemi)

    print({'h': high, 'l': low, 'm': mid})
|
||
|
|
||
|
def load(path='wh_discriminator.n'):
    """Build a ``Model``, restore its weights from ``path``, and return it in eval mode.

    Args:
        path: State-dict file to load. Defaults to the original hard-coded
            ``'wh_discriminator.n'``.

    Returns:
        The restored model, switched to eval mode (dropout disabled).
    """
    model = Model()
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # host; the tensors are copied into the freshly built (CPU) model anyway.
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True on torch >= 1.13.
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()
    return model
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry: restore the trained discriminator, then print sample scores.
    discriminator = load()
    test_scoring(discriminator)
|