# ---------------------------------------------------------------------------
# Reviewer commentary (not part of the patch payload; it sits ahead of the
# first "diff --git" header). The patch text below is reproduced unchanged.
# Its newlines and indentation were collapsed before review, so every remark
# about nesting is inferred from syntax and should be confirmed against the
# real files.
#
# discriminate.py  (new file; hunk header declares 78 added lines)
#   * Model: nn.LSTM(input_size=8, hidden_size=16, num_layers=3, dropout=0.1)
#     feeding nn.Linear(16, 1) and nn.Sigmoid(). forward(x, prev_state)
#     returns (val, state). init_state() does not use its sequence_length
#     argument and returns two torch.zeros(3, 1, 16) tensors.
#   * run(model, seq): turns each seq[i] into its 8 binary digits via
#     bin(...)[2:].zfill(8), feeds them to the model one element at a time and
#     returns y_pred.item().
#     NOTE(review): y_pred is only assigned inside the loop, so an empty seq
#     would reach the return with y_pred unbound - confirm callers never pass
#     empty input.
#     NOTE(review): presumably the return sits after the loop, not inside it;
#     the collapsed text cannot show which.
#   * score(model, ciphertext, hypothesis): shark.xor(ciphertext, hypothesis)
#     passed to run().
#   * test_scoring(model): plaintext and plaintextAlt are os.urandom(length*bs)
#     but plaintextSemi is built over range(length), i.e. 16 items instead of
#     length*bs. shark.xor() zips its inputs, so the 'm' score is computed on
#     a shorter sequence than 'h' and 'l'.
#     NOTE(review): looks unintended - confirm whether range(length*bs) was
#     meant.
#   * load(): builds Model(), loads the state dict from 'wh_discriminator.n',
#     calls eval() and returns the model.
#   * "from torch import nn" is immediately followed by
#     "from torch import nn, optim"; the first import is redundant.
#
# shark.py
#   * Adds genIV(): random.randint(0, 2**(bs-1)).to_bytes(bs, 'big'). The hunk
#     header shows bs = int(256/8), so the bound is 2**31 while the result is
#     32 bytes wide.
#     NOTE(review): presumably the IV was meant to span all bs bytes; also
#     `random` is not a cryptographic source - confirm intent.
#   * enc(): the "ciphertext += ..." line moves below the iv increment. k is
#     computed before both lines and the moved line does not read iv, so this
#     looks behavior-neutral - TODO confirm against the full function.
#   * getSample(): the inline IV expression is replaced by a genIV() call.
#
# train.py  (renamed from discriminator.py, similarity index 54%)
#   * train(): state_h, state_c, blob and correct become two-element lists.
#     Each epoch draws one sample with src 0 and one with src 1 from
#     shark.getSample(), then iterates t in range(2), using t as the BCELoss
#     target. The print gains 'ok0', 'ok1' and 'succ' keys.
#   * NOTE(review): optimizer.zero_grad() appears once, before the per-byte
#     loop, while loss.backward() / optimizer.step() appear among the loop
#     lines. If they run on every step, gradients accumulate across steps -
#     confirm the intended nesting.
# ---------------------------------------------------------------------------
diff --git a/discriminate.py b/discriminate.py new file mode 100644 index 0000000..f048a1d --- /dev/null +++ b/discriminate.py @@ -0,0 +1,78 @@ +import os + +import torch +from torch import nn +from torch import nn, optim +from torch.utils.data import DataLoader +import numpy as np +import random + +import shark + +bs = int(256/8) + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.lstm = nn.LSTM( + input_size=8, + hidden_size=16, + num_layers=3, + dropout=0.1, + ) + self.fc = nn.Linear(16, 1) + self.out = nn.Sigmoid() + + def forward(self, x, prev_state): + output, state = self.lstm(x, prev_state) + logits = self.fc(output) + val = self.out(logits) + #print(str(logits.item())+" > "+str(val.item())) + return val, state + + def init_state(self, sequence_length): + return (torch.zeros(3, 1, 16), + torch.zeros(3, 1, 16)) + +def run(model, seq): + state_h, state_c = model.init_state(len(seq)) + for i in range(len(seq)): + x = torch.tensor([[[float(d) for d in bin(seq[i])[2:].zfill(8)]]], dtype=torch.float32) + y_pred, (state_h, state_c) = model(x, (state_h, state_c)) + + #state_h = state_h.detach() + #state_c = state_c.detach() + + return y_pred.item() + +def score(model, ciphertext, hypothesis): + seq = shark.xor(ciphertext, hypothesis) + return run(model, seq) + +def test_scoring(model): + length = 16 + iv = shark.genIV() + + # TODO: Generate human language + plaintext = os.urandom(length*bs) + plaintextAlt = os.urandom(length*bs) + + plaintextSemi = [plaintext[s] if random.random()>0.5 else plaintextAlt[s] for s in range(length)] + + ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv) + + high = score(model, ciphertext, plaintext) + low = score(model, ciphertext, plaintextAlt) + mid = score(model, ciphertext, plaintextSemi) + + print({'h': high, 'l': low, 'm': mid}) + +def load(): + model = Model() + model.load_state_dict(torch.load('wh_discriminator.n')) + model.eval() + return model + +if 
__name__=="__main__": + m = load() + test_scoring(m) diff --git a/shark.py b/shark.py index 6b9b986..06e37fd 100644 --- a/shark.py +++ b/shark.py @@ -14,14 +14,17 @@ bs = int(256/8) def xor(ta,tb): return bytes(a ^ b for a, b in zip(ta, tb)) +def genIV(): + return random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big') + def enc(plaintext, key, iv): ciphertext = bytes() for i in range(math.ceil(len(plaintext)/bs)): m = hashlib.sha256() m.update(xor(key, iv + i.to_bytes(bs, byteorder='big'))) k = m.digest() - ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0')) iv = (int.from_bytes(iv, byteorder='big')+1).to_bytes(bs, byteorder='big') + ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0')) return ciphertext def dec(ciphertext, key, iv): @@ -34,6 +37,6 @@ def getSample(length, src=None, key=b'VerySecureKeyMustKeepSecretDontTellAnyone' r = os.urandom(length*bs) return (r, 0) else: - iv = random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big') + iv = genIV() b = bytes(length*bs) return (enc(b, key, iv), 1) diff --git a/discriminator.py b/train.py similarity index 54% rename from discriminator.py rename to train.py index 8039044..ec8ae83 100644 --- a/discriminator.py +++ b/train.py @@ -40,26 +40,34 @@ def train(model, seq_len=16*64): criterion = nn.BCELoss() optimizer = optim.Adam(model.parameters(), lr=0.0001) + state_h = [None,None] + state_c = [None,None] + blob = [None,None] + correct = [None,None] + for epoch in range(1024): - state_h, state_c = model.init_state(seq_len) + state_h[0], state_c[0] = model.init_state(seq_len) + state_h[1], state_c[1] = model.init_state(seq_len) - blob, y = shark.getSample(min(seq_len, 16*(epoch+1)), epoch%2) + blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0) + blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1) optimizer.zero_grad() - for i in range(len(blob)): - x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32) - y_pred, (state_h, state_c) = model(x, 
(state_h, state_c)) - loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32)) + for i in range(len(blob[0])): + for t in range(2): + x = torch.tensor([[[float(d) for d in bin(blob[t][i])[2:].zfill(8)]]], dtype=torch.float32) + y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t])) + loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32)) - state_h = state_h.detach() - state_c = state_c.detach() + state_h[t] = state_h[t].detach() + state_c[t] = state_c[t].detach() - loss.backward() - optimizer.step() + loss.backward() + optimizer.step() - correct = round(y_pred.item()) == y - ltLoss = ltLoss*0.9 + 0.1*loss.item() - lltLoss = lltLoss*0.9 + 0.1*ltLoss - print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'correct?': correct }) + correct[t] = round(y_pred.item()) == t + ltLoss = ltLoss*0.9 + 0.1*loss.item() + lltLoss = lltLoss*0.9 + 0.1*ltLoss + print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1] }) if epoch % 8 == 0: torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')