Fixed a problem in model-training and renamed files

This commit is contained in:
Dominik Moritz Roth 2021-09-22 09:14:23 +02:00
parent 30cc846c6f
commit 5f0f9f1dce
3 changed files with 105 additions and 16 deletions

78
discriminate.py Normal file
View File

@ -0,0 +1,78 @@
import os
import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import random
import shark
bs = int(256/8)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.lstm = nn.LSTM(
input_size=8,
hidden_size=16,
num_layers=3,
dropout=0.1,
)
self.fc = nn.Linear(16, 1)
self.out = nn.Sigmoid()
def forward(self, x, prev_state):
output, state = self.lstm(x, prev_state)
logits = self.fc(output)
val = self.out(logits)
#print(str(logits.item())+" > "+str(val.item()))
return val, state
def init_state(self, sequence_length):
return (torch.zeros(3, 1, 16),
torch.zeros(3, 1, 16))
def run(model, seq):
state_h, state_c = model.init_state(len(seq))
for i in range(len(seq)):
x = torch.tensor([[[float(d) for d in bin(seq[i])[2:].zfill(8)]]], dtype=torch.float32)
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
#state_h = state_h.detach()
#state_c = state_c.detach()
return y_pred.item()
def score(model, ciphertext, hypothesis):
seq = shark.xor(ciphertext, hypothesis)
return run(model, seq)
def test_scoring(model):
length = 16
iv = shark.genIV()
# TODO: Generate human language
plaintext = os.urandom(length*bs)
plaintextAlt = os.urandom(length*bs)
plaintextSemi = [plaintext[s] if random.random()>0.5 else plaintextAlt[s] for s in range(length)]
ciphertext = shark.enc(plaintext, b'VerySecureKeyMustKeepSecretDontTellAnyone', iv)
high = score(model, ciphertext, plaintext)
low = score(model, ciphertext, plaintextAlt)
mid = score(model, ciphertext, plaintextSemi)
print({'h': high, 'l': low, 'm': mid})
def load():
model = Model()
model.load_state_dict(torch.load('wh_discriminator.n'))
model.eval()
return model
if __name__=="__main__":
m = load()
test_scoring(m)

View File

@ -14,14 +14,17 @@ bs = int(256/8)
def xor(ta,tb):
return bytes(a ^ b for a, b in zip(ta, tb))
def genIV():
return random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big')
def enc(plaintext, key, iv):
ciphertext = bytes()
for i in range(math.ceil(len(plaintext)/bs)):
m = hashlib.sha256()
m.update(xor(key, iv + i.to_bytes(bs, byteorder='big')))
k = m.digest()
ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0'))
iv = (int.from_bytes(iv, byteorder='big')+1).to_bytes(bs, byteorder='big')
ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0'))
return ciphertext
def dec(ciphertext, key, iv):
@ -34,6 +37,6 @@ def getSample(length, src=None, key=b'VerySecureKeyMustKeepSecretDontTellAnyone'
r = os.urandom(length*bs)
return (r, 0)
else:
iv = random.randint(0, 2**(bs-1)).to_bytes(bs, byteorder='big')
iv = genIV()
b = bytes(length*bs)
return (enc(b, key, iv), 1)

View File

@ -40,26 +40,34 @@ def train(model, seq_len=16*64):
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
state_h = [None,None]
state_c = [None,None]
blob = [None,None]
correct = [None,None]
for epoch in range(1024):
state_h, state_c = model.init_state(seq_len)
state_h[0], state_c[0] = model.init_state(seq_len)
state_h[1], state_c[1] = model.init_state(seq_len)
blob, y = shark.getSample(min(seq_len, 16*(epoch+1)), epoch%2)
blob[0], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 0)
blob[1], _ = shark.getSample(min(seq_len, 16*(epoch+1)), 1)
optimizer.zero_grad()
for i in range(len(blob)):
x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32)
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
loss = criterion(y_pred[0][0][0], torch.tensor(y, dtype=torch.float32))
for i in range(len(blob[0])):
for t in range(2):
x = torch.tensor([[[float(d) for d in bin(blob[t][i])[2:].zfill(8)]]], dtype=torch.float32)
y_pred, (state_h[t], state_c[t]) = model(x, (state_h[t], state_c[t]))
loss = criterion(y_pred[0][0][0], torch.tensor(t, dtype=torch.float32))
state_h = state_h.detach()
state_c = state_c.detach()
state_h[t] = state_h[t].detach()
state_c[t] = state_c[t].detach()
loss.backward()
optimizer.step()
correct = round(y_pred.item()) == y
correct[t] = round(y_pred.item()) == t
ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'correct?': correct })
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1] })
if epoch % 8 == 0:
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')