Increased layer count of model
This commit is contained in:
parent
1b4b6f1a2f
commit
18ba954907
@ -8,6 +8,7 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import shark
|
import shark
|
||||||
|
from model import Model
|
||||||
|
|
||||||
bs = int(256/8)
|
bs = int(256/8)
|
||||||
|
|
||||||
|
26
model.py
Normal file
26
model.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import nn, optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
input_size=8,
|
||||||
|
hidden_size=16,
|
||||||
|
num_layers=5,
|
||||||
|
dropout=0.01,
|
||||||
|
)
|
||||||
|
self.fc = nn.Linear(16, 1)
|
||||||
|
self.out = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x, prev_state):
|
||||||
|
output, state = self.lstm(x, prev_state)
|
||||||
|
logits = self.fc(output)
|
||||||
|
val = self.out(logits)
|
||||||
|
return val, state
|
||||||
|
|
||||||
|
def init_state(self, sequence_length):
|
||||||
|
return (torch.zeros(5, 1, 16),
|
||||||
|
torch.zeros(5, 1, 16))
|
24
train.py
24
train.py
@ -7,29 +7,7 @@ import random
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import shark
|
import shark
|
||||||
|
from model import Model
|
||||||
class Model(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super(Model, self).__init__()
|
|
||||||
self.lstm = nn.LSTM(
|
|
||||||
input_size=8,
|
|
||||||
hidden_size=16,
|
|
||||||
num_layers=3,
|
|
||||||
dropout=0.1,
|
|
||||||
)
|
|
||||||
self.fc = nn.Linear(16, 1)
|
|
||||||
self.out = nn.Sigmoid()
|
|
||||||
|
|
||||||
def forward(self, x, prev_state):
|
|
||||||
output, state = self.lstm(x, prev_state)
|
|
||||||
logits = self.fc(output)
|
|
||||||
val = self.out(logits)
|
|
||||||
#print(str(logits.item())+" > "+str(val.item()))
|
|
||||||
return val, state
|
|
||||||
|
|
||||||
def init_state(self, sequence_length):
|
|
||||||
return (torch.zeros(3, 1, 16),
|
|
||||||
torch.zeros(3, 1, 16))
|
|
||||||
|
|
||||||
def train(model, seq_len=16*64):
|
def train(model, seq_len=16*64):
|
||||||
tid = str(int(random.random()*99999)).zfill(5)
|
tid = str(int(random.random()*99999)).zfill(5)
|
||||||
|
Loading…
Reference in New Issue
Block a user