diff --git a/discriminator.py b/discriminator.py new file mode 100644 index 0000000..f0ded06 --- /dev/null +++ b/discriminator.py @@ -0,0 +1,57 @@ +import torch +from torch import nn +import numpy as np +from torch import nn, optim +from torch.utils.data import DataLoader + +import shark + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.lstm = nn.LSTM( + input_size=8, + hidden_size=16, + num_layers=3, + dropout=0.05, + ) + self.fc = nn.Linear(16, 1) + self.out = nn.Sigmoid() + + def forward(self, x, prev_state): + output, state = self.lstm(x, prev_state) + logits = self.fc(output) + val = self.out(logits) + return val, state + + def init_state(self, sequence_length): + return (torch.zeros(3, 1, 16), + torch.zeros(3, 1, 16)) + +def train(model, seq_len=16*64): + model.train() + + criterion = nn.BCELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + + for epoch in range(1024): + state_h, state_c = model.init_state(seq_len) + + blob, y = shark.getSample(seq_len, epoch%2) + optimizer.zero_grad() + for i in range(len(blob)): + x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32) + y_pred, (state_h, state_c) = model(x, (state_h, state_c)) + loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32)) + + state_h = state_h.detach() + state_c = state_c.detach() + + loss.backward() + optimizer.step() + + print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y}) + +model = Model() + +train(model) diff --git a/shark.py b/shark.py index ea0f162..4549590 100644 --- a/shark.py +++ b/shark.py @@ -1,5 +1,7 @@ import hashlib import math +import os +import random # Shark is a sha256+xor based encryption. # I made it because I want to try to break it. @@ -7,17 +9,30 @@ import math # This will work iff I succeed in building a PPT-discriminator for sha256 from randomness # As my first approach this discriminator will be based on an LSTM-network. +bs = int(256/8) + def xor(ta,tb): return bytes(a ^ b for a, b in zip(ta, tb)) def enc(plaintext, key, iv): ciphertext = bytes() - bs = 256/8 - for i in range(math.ceil(len(plaintext/bs))): + for i in range(math.ceil(len(plaintext)/bs)): m = hashlib.sha256() m.update(xor(key, iv + i.to_bytes(bs, byteorder='big'))) k = m.digest() - ciphertext += sxor(k, plaintext[bs*i:][:bs]) + ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0')) + return ciphertext def dec(ciphertext, key, iv): return enc(ciphertext, key, iv) + +def getSample(length, src=None, key=b'VerySecureKeyMustKeepSecretDontTellAnyone'): + if src==None: + src = random.random() > 0.5 + if not src: + r = os.urandom(length*bs) + return (r, 0) + else: + iv = random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big') + b = bytes(length*bs) + return (enc(b, key, iv), 1)