2021-09-22 09:14:23 +02:00
|
|
|
import os
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn, optim
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
import numpy as np
|
|
|
|
import random
|
|
|
|
|
|
|
|
import shark
|
2021-09-22 10:37:11 +02:00
|
|
|
from model import Model
|
2021-09-22 09:14:23 +02:00
|
|
|
|
2021-09-22 22:27:53 +02:00
|
|
|
# Size of one shark block; presumably in bytes, since test_scoring multiplies
# it by a block count to size os.urandom() plaintexts — TODO confirm in shark.
bs = shark.bs
|
2021-09-22 09:14:23 +02:00
|
|
|
|
|
|
|
class Model(nn.Module):
    """LSTM discriminator mapping a sequence of feature vectors to scores in (0, 1).

    The LSTM output at every time step is projected to a single logit by a
    linear layer and squashed with a sigmoid.

    The submodule attribute names (``lstm``, ``fc``, ``out``) and the default
    hyper-parameters must stay as they are so that previously saved state
    dicts (e.g. the one ``load()`` reads) still load into ``Model()``.
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=3, dropout=0.1):
        """Build the network.

        Args:
            input_size: Features per time step (``run`` feeds 8: the bits of a byte).
            hidden_size: Width of the LSTM hidden and cell states.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout between stacked LSTM layers (inactive in eval mode).
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM over ``x`` starting from ``prev_state``.

        Args:
            x: Input in ``nn.LSTM``'s default (seq, batch, feature) layout
                (``batch_first`` is not set).
            prev_state: ``(h, c)`` tuple, e.g. from :meth:`init_state`.

        Returns:
            ``(val, state)``: ``val`` holds per-step sigmoid scores of shape
            (seq, batch, 1); ``state`` is the updated ``(h, c)`` tuple.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        return self.out(logits), state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zeroed ``(h, c)`` LSTM state.

        The shape is read from the LSTM itself, so it stays correct if the
        hyper-parameters passed to ``__init__`` change.

        Args:
            sequence_length: Unused; kept for call compatibility (the state
                shape does not depend on the sequence length).
            batch_size: Batch dimension of the state (default 1).

        Returns:
            Two float32 zero tensors of shape (num_layers, batch_size, hidden_size).
        """
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return torch.zeros(shape), torch.zeros(shape)
|
|
|
|
|
|
|
|
def _byte_to_bits(value):
    """Return the 8 bits of ``value`` as floats, most significant bit first.

    Raises:
        ValueError: if ``value`` is outside 0..255 (the model expects exactly
            8 features per time step).
    """
    if not 0 <= value <= 255:
        raise ValueError(f"expected a byte value in 0..255, got {value!r}")
    return [float((value >> shift) & 1) for shift in range(7, -1, -1)]


def run(model, seq):
    """Feed ``seq`` one element at a time through ``model``; return the final score.

    Each element is encoded as an 8-wide float vector of its bits (MSB first)
    and passed as a single (1, 1, 8) time step, carrying the recurrent state
    forward from ``model.init_state``.

    Args:
        model: Object with ``init_state(n)`` and ``model(x, state) -> (y, state)``,
            such as ``Model``.
        seq: Non-empty sequence of ints in 0..255 (e.g. ``bytes``).

    Returns:
        The model output after the last element, as a Python float.

    Raises:
        ValueError: if ``seq`` is empty or contains a value outside 0..255.
    """
    if len(seq) == 0:
        # Without this the loop never runs and y_pred is unbound.
        raise ValueError("run() needs a non-empty sequence")
    state = model.init_state(len(seq))
    # Inference only (the result is reduced to a float): don't build an
    # autograd graph spanning the whole sequence.
    with torch.no_grad():
        for value in seq:
            x = torch.tensor([[_byte_to_bits(value)]], dtype=torch.float32)
            y_pred, state = model(x, state)
    return y_pred.item()
|
|
|
|
|
|
|
|
def score(model, ciphertext, hypothesis):
    """Return ``model``'s score for ``hypothesis`` against ``ciphertext``.

    The two inputs are combined with ``shark.xor`` and the combined sequence
    is passed to ``run``, whose float result is returned unchanged.
    """
    return run(model, shark.xor(ciphertext, hypothesis))
|
|
|
|
|
|
|
|
def test_scoring(model, length=16):
    """Print (and return) the model's scores for three plaintext hypotheses.

    A random plaintext of ``length`` blocks is encrypted, then the ciphertext
    is scored against the true plaintext (``'h'``), an unrelated random
    plaintext (``'l'``), and a per-byte 50/50 mix of the two (``'m'``).

    Args:
        model: Discriminator passed through to ``score``.
        length: Number of ``bs``-sized blocks of plaintext (default 16).

    Returns:
        Dict ``{'h': ..., 'l': ..., 'm': ...}`` with the three scores.
    """
    iv = shark.genIV()

    # TODO: Generate human language
    plaintext = os.urandom(length * bs)
    plaintextAlt = os.urandom(length * bs)

    # Mix over every byte (not just the first `length` bytes) and keep the
    # result as bytes, so this hypothesis has the same length and type as
    # the other two.
    plaintextSemi = bytes(
        true_byte if random.random() > 0.5 else alt_byte
        for true_byte, alt_byte in zip(plaintext, plaintextAlt)
    )

    ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)

    results = {
        'h': score(model, ciphertext, plaintext),
        'l': score(model, ciphertext, plaintextAlt),
        'm': score(model, ciphertext, plaintextSemi),
    }
    print(results)
    return results
|
|
|
|
|
|
|
|
def load(path='wh_discriminator.n'):
    """Build a ``Model`` and restore its weights from ``path``.

    Args:
        path: State-dict file readable by ``torch.load`` (default
            ``'wh_discriminator.n'``, relative to the current directory).

    Returns:
        The model, switched to eval mode.
    """
    model = Model()
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine; load_state_dict copies the values into the model's
    # own parameters, so the resulting model is the same.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the installed torch has it).
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: restore the saved discriminator, then run the
    # scoring smoke test on it.
    m = load()
    test_scoring(model=m)
|