import os
import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
# NOTE(review): this imported Model is shadowed by the local class below.
from model import Model

bs = shark.bs


class Model(nn.Module):
    """Three-layer LSTM discriminator producing a sigmoid score in (0, 1).

    Each time step takes 8 input features; ``run`` feeds it the bits of one
    byte per step.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(
            input_size=8,
            hidden_size=16,
            num_layers=3,
            dropout=0.1,
        )
        self.fc = nn.Linear(16, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Advance the LSTM over ``x`` and return ``(score, new_state)``.

        The LSTM is not ``batch_first``, so ``x`` is laid out as
        (seq_len, batch, 8); ``prev_state`` is an ``(h, c)`` tuple such as the
        one returned by ``init_state``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length):
        """Return zeroed ``(h, c)`` tensors of shape (num_layers, 1, hidden_size).

        ``sequence_length`` is unused; it is kept for interface compatibility.
        """
        shape = (self.lstm.num_layers, 1, self.lstm.hidden_size)  # (3, 1, 16)
        return (torch.zeros(*shape), torch.zeros(*shape))


def _byte_to_bits(value):
    """Encode one byte value (0-255) as a (1, 1, 8) float tensor of bits, MSB first."""
    bits = [float(d) for d in bin(value)[2:].zfill(8)]
    return torch.tensor([[bits]], dtype=torch.float32)


def run(model, seq):
    """Feed ``seq`` through ``model`` one byte per step and return the last score.

    Args:
        model: a ``Model`` instance.
        seq: a non-empty sequence of integer byte values.

    Returns:
        The model's output after the final step, as a Python float.

    Raises:
        ValueError: if ``seq`` is empty (no step would produce a prediction).
    """
    if len(seq) == 0:
        raise ValueError("run() requires a non-empty sequence")
    state = model.init_state(len(seq))
    # Inference only (the result is reduced with .item()), so skip building an
    # autograd graph that would otherwise grow with every time step.
    with torch.no_grad():
        for value in seq:
            y_pred, state = model(_byte_to_bits(value), state)
    return y_pred.item()


def score(model, ciphertext, hypothesis):
    """Score a plaintext ``hypothesis`` by running ``ciphertext XOR hypothesis`` through ``model``."""
    seq = shark.xor(ciphertext, hypothesis)
    return run(model, seq)


def test_scoring(model):
    """Print scores for the true plaintext, a random one, and a byte-wise mix of both.

    The keys 'h' / 'l' / 'm' presumably denote the expected high / low / middle
    scores of a well-trained discriminator.
    """
    length = 16
    n_bytes = length * bs
    iv = shark.genIV()
    # TODO: Generate human language plaintext
    plaintext = os.urandom(n_bytes)
    plaintextAlt = os.urandom(n_bytes)
    # Mix across all n_bytes so this hypothesis has the same length (and type)
    # as the other two rather than covering only the first `length` bytes.
    plaintextSemi = bytes(
        plaintext[s] if random.random() > 0.5 else plaintextAlt[s]
        for s in range(n_bytes)
    )
    ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)
    high = score(model, ciphertext, plaintext)
    low = score(model, ciphertext, plaintextAlt)
    mid = score(model, ciphertext, plaintextSemi)
    print({'h': high, 'l': low, 'm': mid})


def load(path='wh_discriminator.n'):
    """Build a ``Model``, load weights from ``path`` and switch it to eval mode.

    ``map_location='cpu'`` lets a checkpoint saved from a GPU load into this
    CPU-resident model.
    """
    model = Model()
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()
    return model


if __name__ == "__main__":
    m = load()
    test_scoring(m)