# shark/discriminate.py

import os
import random

import torch
from torch import nn

import shark

bs = 256 // 8  # 256 bits expressed as bytes (32)

class Model(nn.Module):
    """LSTM discriminator: consumes a byte sequence one byte per timestep
    (each byte encoded as 8 binary features) and outputs a score in [0, 1]."""

    def __init__(self):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(
            input_size=8,
            hidden_size=16,
            num_layers=3,
            dropout=0.1,
        )
        self.fc = nn.Linear(16, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        # print(f"{logits.item()} > {val.item()}")
        return val, state

    def init_state(self, sequence_length):
        # Zero-initialised (h, c) for num_layers=3, batch size 1, hidden_size=16.
        # sequence_length is currently unused.
        return (torch.zeros(3, 1, 16),
                torch.zeros(3, 1, 16))

def run(model, seq):
    """Feed a byte sequence through the model one byte at a time and return
    the final score."""
    state_h, state_c = model.init_state(len(seq))
    for i in range(len(seq)):
        # Encode the byte as 8 binary features, shaped (seq_len=1, batch=1, input_size=8).
        x = torch.tensor([[[float(d) for d in bin(seq[i])[2:].zfill(8)]]], dtype=torch.float32)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        # state_h = state_h.detach()
        # state_c = state_c.detach()
    return y_pred.item()

def score(model, ciphertext, hypothesis):
    """XOR the hypothesis against the ciphertext and score the result."""
    seq = shark.xor(ciphertext, hypothesis)
    return run(model, seq)
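
# Example usage (a sketch: 'key' and 'guess' are placeholders, and the shark
# API is assumed to behave as in test_scoring below). A score near 1 is
# expected when the hypothesis matches the encrypted plaintext and a score
# near 0 when it does not.
#
#   model = load()
#   iv = shark.genIV()
#   ciphertext = shark.enc(plaintext, key, iv)
#   print(score(model, ciphertext, guess))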

def test_scoring(model):
    length = 16
    iv = shark.genIV()
    # TODO: Generate human language
    plaintext = os.urandom(length * bs)
    plaintextAlt = os.urandom(length * bs)
    # Hypothesis agreeing with the real plaintext in roughly half of its bytes.
    plaintextSemi = bytes(plaintext[s] if random.random() > 0.5 else plaintextAlt[s]
                          for s in range(length * bs))
    ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)
    high = score(model, ciphertext, plaintext)
    low = score(model, ciphertext, plaintextAlt)
    mid = score(model, ciphertext, plaintextSemi)
    print({'h': high, 'l': low, 'm': mid})
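
# With a trained discriminator, 'h' (true plaintext) is expected to score
# highest, 'l' (unrelated plaintext) lowest, and 'm' (half-matching) in between.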

def load():
    model = Model()
    model.load_state_dict(torch.load('wh_discriminator.n'))
    model.eval()
    return model

if __name__ == "__main__":
    m = load()
    test_scoring(m)