27 lines
701 B
Python
27 lines
701 B
Python
|
import torch
|
||
|
from torch import nn
|
||
|
from torch import nn, optim
|
||
|
from torch.utils.data import DataLoader
|
||
|
|
||
|
class Model(nn.Module):
    """Stacked LSTM followed by a per-time-step linear head and a sigmoid.

    Every time step of the LSTM output is projected to a single value by
    ``fc`` and squashed into (0, 1) by ``out``.

    The constructor defaults reproduce the original fixed architecture
    (8 input features, 16 hidden units, 5 layers, dropout 0.01).

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden/cell state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers.
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=5, dropout=0.01):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        # The head's input width must track the LSTM hidden size.
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run the LSTM and map each time step's output to a value in (0, 1).

        Args:
            x: Input sequence. ``nn.LSTM`` is built without ``batch_first``,
                so batched input is laid out as (seq_len, batch, input_size).
            prev_state: ``(h_0, c_0)`` tuple, e.g. from ``init_state``.

        Returns:
            ``(val, state)``. ``val`` is the sigmoid of the linear head
            applied to every LSTM output step. ``state`` is the LSTM's
            ``(h_n, c_n)``, which can be fed back in as ``prev_state``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length, batch_size=1):
        """Return a zero ``(h_0, c_0)`` initial state for the LSTM.

        Args:
            sequence_length: Unused. Kept so existing callers keep working.
            batch_size: Batch dimension of the returned state (default 1,
                matching the original behaviour).

        Returns:
            Tuple of two zero tensors, each of shape
            (num_layers, batch_size, hidden_size), created on the same
            device and with the same dtype as the model's parameters.
        """
        # Derive sizes from the LSTM itself instead of repeating literals,
        # and allocate via an existing parameter so device/dtype match.
        weight = next(self.parameters())
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (weight.new_zeros(shape), weight.new_zeros(shape))
|