Increased layer count of model

This commit is contained in:
Dominik Moritz Roth 2021-09-22 10:37:11 +02:00
parent 1b4b6f1a2f
commit 18ba954907
3 changed files with 28 additions and 23 deletions

View File

@ -8,6 +8,7 @@ import numpy as np
import random import random
import shark import shark
from model import Model
bs = int(256/8) bs = int(256/8)

26
model.py Normal file
View File

@ -0,0 +1,26 @@
import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.lstm = nn.LSTM(
input_size=8,
hidden_size=16,
num_layers=5,
dropout=0.01,
)
self.fc = nn.Linear(16, 1)
self.out = nn.Sigmoid()
def forward(self, x, prev_state):
output, state = self.lstm(x, prev_state)
logits = self.fc(output)
val = self.out(logits)
return val, state
def init_state(self, sequence_length):
return (torch.zeros(5, 1, 16),
torch.zeros(5, 1, 16))

View File

@ -7,29 +7,7 @@ import random
import math import math
import shark import shark
from model import Model
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.lstm = nn.LSTM(
input_size=8,
hidden_size=16,
num_layers=3,
dropout=0.1,
)
self.fc = nn.Linear(16, 1)
self.out = nn.Sigmoid()
def forward(self, x, prev_state):
output, state = self.lstm(x, prev_state)
logits = self.fc(output)
val = self.out(logits)
#print(str(logits.item())+" > "+str(val.item()))
return val, state
def init_state(self, sequence_length):
return (torch.zeros(3, 1, 16),
torch.zeros(3, 1, 16))
def train(model, seq_len=16*64): def train(model, seq_len=16*64):
tid = str(int(random.random()*99999)).zfill(5) tid = str(int(random.random()*99999)).zfill(5)