Updated README and better epoch-printing
This commit is contained in:
parent
5f0f9f1dce
commit
1b4b6f1a2f
@ -6,3 +6,4 @@ I made it because I want to try to break it.
|
|||||||
This will work iff I succeed in building a PPT-discriminator for sha256 from randomness
|
This will work iff I succeed in building a PPT-discriminator for sha256 from randomness
|
||||||
As my first approach this discriminator will be based on an LSTM-network.
|
As my first approach this discriminator will be based on an LSTM-network.
|
||||||
Update: This worked out way better than expected; given long enought sequences (128 Bytes are more than enough) we can discriminate successfully in 100% of cases.
|
Update: This worked out way better than expected; given long enought sequences (128 Bytes are more than enough) we can discriminate successfully in 100% of cases.
|
||||||
|
Update: I did an upsie in the training-code and the discriminator is actually shit.
|
||||||
|
5
train.py
5
train.py
@ -4,6 +4,7 @@ from torch import nn, optim
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
import shark
|
import shark
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ def train(model, seq_len=16*64):
|
|||||||
tid = str(int(random.random()*99999)).zfill(5)
|
tid = str(int(random.random()*99999)).zfill(5)
|
||||||
print("[i] I am "+str(tid))
|
print("[i] I am "+str(tid))
|
||||||
ltLoss = 50
|
ltLoss = 50
|
||||||
lltLoss = 51
|
lltLoss = 52
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
criterion = nn.BCELoss()
|
criterion = nn.BCELoss()
|
||||||
@ -67,7 +68,7 @@ def train(model, seq_len=16*64):
|
|||||||
correct[t] = round(y_pred.item()) == t
|
correct[t] = round(y_pred.item()) == t
|
||||||
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
ltLoss = ltLoss*0.9 + 0.1*loss.item()
|
||||||
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
lltLoss = lltLoss*0.9 + 0.1*ltLoss
|
||||||
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1] })
|
print({ 'epoch': epoch, 'loss': loss.item(), 'ltLoss': ltLoss, 'ok0': correct[0], 'ok1': correct[1], 'succ': correct[0] and correct[1], 'acc': str(int(max(0, 1-math.sqrt(lltLoss))*100))+"%" })
|
||||||
if epoch % 8 == 0:
|
if epoch % 8 == 0:
|
||||||
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
torch.save(model.state_dict(), 'model_savepoints/'+tid+'_'+str(epoch)+'.n')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user