shark/model.py

29 lines
715 B
Python
Raw Permalink Normal View History

import torch
from torch import nn
from torch import nn, optim
from torch.utils.data import DataLoader
import shark
class Model(nn.Module):
    """Stacked LSTM followed by a linear projection and a sigmoid.

    ``nn.LSTM`` is built without ``batch_first``, so it uses PyTorch's default
    layout: inputs are ``(seq_len, batch, input_size)`` and the returned values
    are ``(seq_len, batch, 1)`` -- one sigmoid output per time step.

    Args:
        input_size: Number of features per time step (default 8).
        hidden_size: Width of the LSTM hidden/cell state (default 16).
        num_layers: Number of stacked LSTM layers (default 5).
        dropout: Dropout that ``nn.LSTM`` applies between stacked layers
            (default 0.01).
    """

    def __init__(self, input_size=8, hidden_size=16, num_layers=5, dropout=0.01):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        # Map each time step's hidden output to a single scalar.
        self.fc = nn.Linear(hidden_size, 1)
        self.out = nn.Sigmoid()

    def forward(self, x, prev_state=None):
        """Run the LSTM over ``x`` and squash each step's projection with a sigmoid.

        Args:
            x: Input tensor laid out as ``(seq_len, batch, input_size)``.
            prev_state: ``(h, c)`` tuple, e.g. from :meth:`init_state` or the
                ``state`` returned by a previous call. ``None`` lets
                ``nn.LSTM`` start from an all-zero state of the right shape.

        Returns:
            ``(val, state)``. ``val`` holds sigmoid outputs in ``[0, 1]`` with
            shape ``(seq_len, batch, 1)``. ``state`` is the final ``(h, c)``
            tuple to carry into the next call.
        """
        output, state = self.lstm(x, prev_state)
        logits = self.fc(output)
        val = self.out(logits)
        return val, state

    def init_state(self, sequence_length=None, *, batch_size=1):
        """Return an all-zero ``(h, c)`` LSTM state matching this model.

        The shape is ``(num_layers, batch_size, hidden_size)``, read from the
        LSTM itself so it always agrees with the constructor arguments. The
        tensors are created on the same device and with the same dtype as the
        LSTM weights, so the state stays valid after ``.to(...)``/``.double()``.

        Args:
            sequence_length: Unused; accepted only so existing callers that
                pass it positionally keep working.
            batch_size: Batch dimension of the state (default 1).

        Returns:
            Tuple ``(h0, c0)`` of zero tensors.
        """
        # Any LSTM parameter carries the device/dtype the state must match.
        ref = next(self.lstm.parameters())
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return (
            torch.zeros(shape, device=ref.device, dtype=ref.dtype),
            torch.zeros(shape, device=ref.device, dtype=ref.dtype),
        )