Started implementation of the PPT-Attacker
This commit is contained in:
parent
5f122e3cbf
commit
f4444f3a9e
57
discriminator.py
Normal file
57
discriminator.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
from torch import nn, optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
import shark
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
input_size=8,
|
||||||
|
hidden_size=16,
|
||||||
|
num_layers=3,
|
||||||
|
dropout=0.05,
|
||||||
|
)
|
||||||
|
self.fc = nn.Linear(16, 1)
|
||||||
|
self.out = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x, prev_state):
|
||||||
|
output, state = self.lstm(x, prev_state)
|
||||||
|
logits = self.fc(output)
|
||||||
|
val = self.out(logits)
|
||||||
|
return val, state
|
||||||
|
|
||||||
|
def init_state(self, sequence_length):
|
||||||
|
return (torch.zeros(3, 1, 16),
|
||||||
|
torch.zeros(3, 1, 16))
|
||||||
|
|
||||||
|
def train(model, seq_len=16*64):
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
criterion = nn.BCELoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
for epoch in range(1024):
|
||||||
|
state_h, state_c = model.init_state(seq_len)
|
||||||
|
|
||||||
|
blob, y = shark.getSample(seq_len, epoch%2)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
for i in range(len(blob)):
|
||||||
|
x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32)
|
||||||
|
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
|
||||||
|
loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32))
|
||||||
|
|
||||||
|
state_h = state_h.detach()
|
||||||
|
state_c = state_c.detach()
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print({ 'epoch': epoch, 'loss': loss.item(), 'err': float(y_pred[0][0])- y})
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
|
||||||
|
train(model)
|
21
shark.py
21
shark.py
@ -1,5 +1,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
# Shark is a sha256+xor based encryption.
|
# Shark is a sha256+xor based encryption.
|
||||||
# I made it because I want to try to break it.
|
# I made it because I want to try to break it.
|
||||||
@ -7,17 +9,30 @@ import math
|
|||||||
# This will work iff I succeed in building a PPT-discriminator for sha256 from randomness
|
# This will work iff I succeed in building a PPT-discriminator for sha256 from randomness
|
||||||
# As my first approach this discriminator will be based on an LSTM-network.
|
# As my first approach this discriminator will be based on an LSTM-network.
|
||||||
|
|
||||||
|
bs = int(256/8)
|
||||||
|
|
||||||
def xor(ta,tb):
|
def xor(ta,tb):
|
||||||
return bytes(a ^ b for a, b in zip(ta, tb))
|
return bytes(a ^ b for a, b in zip(ta, tb))
|
||||||
|
|
||||||
def enc(plaintext, key, iv):
|
def enc(plaintext, key, iv):
|
||||||
ciphertext = bytes()
|
ciphertext = bytes()
|
||||||
bs = 256/8
|
for i in range(math.ceil(len(plaintext)/bs)):
|
||||||
for i in range(math.ceil(len(plaintext/bs))):
|
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
m.update(xor(key, iv + i.to_bytes(bs, byteorder='big')))
|
m.update(xor(key, iv + i.to_bytes(bs, byteorder='big')))
|
||||||
k = m.digest()
|
k = m.digest()
|
||||||
ciphertext += sxor(k, plaintext[bs*i:][:bs])
|
ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0'))
|
||||||
|
return ciphertext
|
||||||
|
|
||||||
def dec(ciphertext, key, iv):
|
def dec(ciphertext, key, iv):
|
||||||
return enc(ciphertext, key, iv)
|
return enc(ciphertext, key, iv)
|
||||||
|
|
||||||
|
def getSample(length, src=None, key=b'VerySecureKeyMustKeepSecretDontTellAnyone'):
|
||||||
|
if src==None:
|
||||||
|
src = random.random() > 0.5
|
||||||
|
if not src:
|
||||||
|
r = os.urandom(length*bs)
|
||||||
|
return (r, 0)
|
||||||
|
else:
|
||||||
|
iv = random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big')
|
||||||
|
b = bytes(length*bs)
|
||||||
|
return (enc(b, key, iv), 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user