Fix: Encryption did not correctly increment iv

This commit is contained in:
Dominik Moritz Roth 2021-09-21 15:54:29 +02:00
parent 0c197e3e56
commit 42ed2dd676
2 changed files with 7 additions and 3 deletions

View File

@ -23,6 +23,7 @@ class Model(nn.Module):
output, state = self.lstm(x, prev_state) output, state = self.lstm(x, prev_state)
logits = self.fc(output) logits = self.fc(output)
val = self.out(logits) val = self.out(logits)
#print(str(logits.item())+" > "+str(val.item()))
return val, state return val, state
def init_state(self, sequence_length): def init_state(self, sequence_length):
@ -31,17 +32,18 @@ class Model(nn.Module):
def train(model, seq_len=16*64): def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5) tid = str(int(random.random()*99999)).zfill(5)
print("[i] I am "+str(tid))
ltLoss = 50 ltLoss = 50
lltLoss = 51 lltLoss = 51
model.train() model.train()
criterion = nn.BCELoss() criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.1) optimizer = optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(1024): for epoch in range(1024):
state_h, state_c = model.init_state(seq_len) state_h, state_c = model.init_state(seq_len)
blob, y = shark.getSample(seq_len, epoch%2) blob, y = shark.getSample(min(seq_len, 16*(epoch+1)), epoch%2)
optimizer.zero_grad() optimizer.zero_grad()
for i in range(len(blob)): for i in range(len(blob)):
x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32) x = torch.tensor([[[float(d) for d in bin(blob[i])[2:].zfill(8)]]], dtype=torch.float32)
@ -54,9 +56,10 @@ def train(model, seq_len=16*64):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
correct = round(y_pred.item()) == y
ltLoss = ltLoss*0.9 + 0.1*loss.item() ltLoss = ltLoss*0.9 + 0.1*loss.item()
lltLoss = lltLoss*0.9 + 0.1*ltLoss lltLoss = lltLoss*0.9 + 0.1*ltLoss
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss}) print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'correct?': correct})
if ltLoss < 0.20 and lltLoss < 0.225: if ltLoss < 0.20 and lltLoss < 0.225:
print("[*] Hell Yeah! Poccing! Got sup") print("[*] Hell Yeah! Poccing! Got sup")
if epoch % 8 == 0: if epoch % 8 == 0:

View File

@ -21,6 +21,7 @@ def enc(plaintext, key, iv):
m.update(xor(key, iv + i.to_bytes(bs, byteorder='big'))) m.update(xor(key, iv + i.to_bytes(bs, byteorder='big')))
k = m.digest() k = m.digest()
ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0')) ciphertext += xor(k, plaintext[bs*i:][:bs].ljust(bs, b'0'))
iv = (int.from_bytes(iv, byteorder='big')+1).to_bytes(bs, byteorder='big')
return ciphertext return ciphertext
def dec(ciphertext, key, iv): def dec(ciphertext, key, iv):