Updated README and better epoch-printing

Dominik Moritz Roth 2021-09-22 10:28:08 +02:00
parent 5f0f9f1dce
commit 1b4b6f1a2f
2 changed files with 4 additions and 2 deletions

@@ -6,3 +6,4 @@ I made it because I want to try to break it.
 This will work iff I succeed in building a PPT-discriminator for sha256 from randomness
 As my first approach this discriminator will be based on an LSTM-network.
 Update: This worked out way better than expected; given long enought sequences (128 Bytes are more than enough) we can discriminate successfully in 100% of cases.
+Update: I did an upsie in the training-code and the discriminator is actually shit.
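
For context, the LSTM discriminator the README refers to is not shown in this diff. A minimal sketch of what a byte-level LSTM discriminator (sha256 output vs. true randomness) could look like in PyTorch; the class name, layer sizes and byte embedding are illustrative assumptions, not the repo's actual model:

# Hypothetical sketch, not code from this repo.
import torch
from torch import nn

class ByteLSTMDiscriminator(nn.Module):
    def __init__(self, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(256, 32)              # one vector per byte value
        self.lstm = nn.LSTM(32, hidden, batch_first=True)
        self.head = nn.Sequential(nn.Linear(hidden, 1), nn.Sigmoid())

    def forward(self, x):                               # x: (batch, seq_len) of byte values 0..255
        _, (h, _) = self.lstm(self.embed(x))            # use the final hidden state
        return self.head(h[-1]).squeeze(-1)             # probability the sequence is "truly random"

# usage: p = ByteLSTMDiscriminator()(torch.randint(0, 256, (1, 128)))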

@@ -4,6 +4,7 @@ from torch import nn, optim
 from torch.utils.data import DataLoader
 import numpy as np
 import random
+import math
 import shark
@@ -34,7 +35,7 @@ def train(model, seq_len=16*64):
 tid = str(int(random.random()*99999)).zfill(5)
 print("[i] I am "+str(tid))
 ltLoss = 50
-lltLoss = 51
+lltLoss = 52
 model.train()
 criterion = nn.BCELoss()
@@ -67,7 +68,7 @@ def train(model, seq_len=16*64):
 correct[t] = round(y_pred.item()) == t
 ltLoss = ltLoss*0.9 + 0.1*loss.item()
 lltLoss = lltLoss*0.9 + 0.1*ltLoss
-print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1] })
+print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
 if epoch % 8 == 0:
 torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
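
The only functional change in this hunk is the new 'acc' field in the epoch print: lltLoss is an exponential moving average of an exponential moving average of the BCE loss, and accuracy is roughly estimated as max(0, 1 - sqrt(lltLoss)). A standalone sketch of that heuristic, with the initial values taken from the hunk above and an illustrative helper name:

# Illustration of the printed 'acc' heuristic, not code from this repo.
# The high initial values (50/52) keep the estimate pessimistic (0%) until the
# smoothed loss has actually come down.
import math

def smoothed_loss(losses, ltLoss=50.0, lltLoss=52.0):
    for loss in losses:
        ltLoss = ltLoss*0.9 + 0.1*loss      # EMA of the raw loss
        lltLoss = lltLoss*0.9 + 0.1*ltLoss  # EMA of the EMA, lags a little more
    return lltLoss

acc = max(0, 1 - math.sqrt(smoothed_loss([0.05] * 200)))
print(str(int(acc*100)) + "%")              # prints "77%" once the loss settles near 0.05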