From 18ba9549075d436a5e512ceb91c6352803ef9386 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 22 Sep 2021 10:37:11 +0200 Subject: [PATCH] Increased layer count of model --- discriminate.py | 1 + model.py | 26 ++++++++++++++++++++++++++ train.py | 24 +----------------------- 3 files changed, 28 insertions(+), 23 deletions(-) create mode 100644 model.py diff --git a/discriminate.py b/discriminate.py index f048a1d..c8fadf8 100644 --- a/discriminate.py +++ b/discriminate.py @@ -8,6 +8,7 @@ import numpy as np import random import shark +from model import Model bs = int(256/8) diff --git a/model.py b/model.py new file mode 100644 index 0000000..a42fec5 --- /dev/null +++ b/model.py @@ -0,0 +1,26 @@ +import torch +from torch import nn +from torch import nn, optim +from torch.utils.data import DataLoader + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.lstm = nn.LSTM( + input_size=8, + hidden_size=16, + num_layers=5, + dropout=0.01, + ) + self.fc = nn.Linear(16, 1) + self.out = nn.Sigmoid() + + def forward(self, x, prev_state): + output, state = self.lstm(x, prev_state) + logits = self.fc(output) + val = self.out(logits) + return val, state + + def init_state(self, sequence_length): + return (torch.zeros(5, 1, 16), + torch.zeros(5, 1, 16)) diff --git a/train.py b/train.py index 4bb2ac8..4d33f32 100644 --- a/train.py +++ b/train.py @@ -7,29 +7,7 @@ import random import math import shark - -class Model(nn.Module): - def __init__(self): - super(Model, self).__init__() - self.lstm = nn.LSTM( - input_size=8, - hidden_size=16, - num_layers=3, - dropout=0.1, - ) - self.fc = nn.Linear(16, 1) - self.out = nn.Sigmoid() - - def forward(self, x, prev_state): - output, state = self.lstm(x, prev_state) - logits = self.fc(output) - val = self.out(logits) - #print(str(logits.item())+" > "+str(val.item())) - return val, state - - def init_state(self, sequence_length): - return (torch.zeros(3, 1, 16), - torch.zeros(3, 1, 16)) +from model import Model def train(model, seq_len=16*64): tid = str(int(random.random()*99999)).zfill(5)