import torch
from torch import nn, optim
from torch.utils.data import DataLoader


class Model(nn.Module):
    """Stacked LSTM followed by a linear head and a sigmoid output.

    The LSTM is built with ``batch_first=False`` (the ``nn.LSTM`` default), so
    batched inputs are laid out as ``(seq_len, batch, input_size)``.
    """

    def __init__(
        self,
        input_size: int = 8,
        hidden_size: int = 16,
        num_layers: int = 5,
        dropout: float = 0.01,
    ) -> None:
        """Build the network.

        Args:
            input_size: Number of features per time step fed to the LSTM.
            hidden_size: Width of each LSTM layer's hidden/cell state.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout applied between stacked LSTM layers.

        The defaults reproduce the original hard-coded architecture.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state):
        """Run one pass through the LSTM, the linear head and the sigmoid.

        Args:
            x: LSTM input, ``(seq_len, batch, input_size)`` when batched.
            prev_state: ``(h, c)`` tuple, e.g. as returned by ``init_state``.

        Returns:
            ``(val, state)``: ``val`` holds the sigmoid output for every time
            step (last dim 1); ``state`` is the LSTM's final ``(h, c)``.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length, batch_size: int = 1):
        """Return a zero ``(h, c)`` state tuple for the LSTM.

        Args:
            sequence_length: Unused; kept so existing callers keep working.
            batch_size: Batch dimension of the state (default 1, as before).

        Returns:
            Two zero tensors of shape ``(num_layers, batch_size, hidden_size)``.
        """
        # Take the device and dtype from the model's own weights, so the state
        # still matches after .to("cuda") or .double(). Plain torch.zeros
        # would always allocate CPU float32 tensors.
        ref = self.fc.weight
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (
            torch.zeros(shape, device=ref.device, dtype=ref.dtype),
            torch.zeros(shape, device=ref.device, dtype=ref.dtype),
        )