2024-05-24 22:01:59 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
2024-05-25 00:53:30 +02:00
|
|
|
class BaseModel(nn.Module, ABC):
    """Abstract interface for predictive-coding models.

    Subclasses provide a network (``forward``) together with a matching
    ``encode``/``decode`` pair.

    Mixing in ``ABC`` gives the class the ``ABCMeta`` metaclass. That is what
    makes ``@abstractmethod`` enforceable: instantiating ``BaseModel``, or any
    subclass that leaves one of the methods below unimplemented, raises
    ``TypeError``. ``nn.Module`` on its own is not an ABC, so without the
    mix-in the decorators had no effect.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Run the network on input tensor ``x`` and return its output."""

    @abstractmethod
    def encode(self, data):
        """Transform the sequence ``data`` into its encoded representation."""

    @abstractmethod
    def decode(self, encoded_data):
        """Invert ``encode``, reconstructing the original sequence."""
|
|
|
|
|
|
|
|
class LSTMPredictor(BaseModel):
    """Next-value predictor built on an LSTM, used for predictive delta coding.

    ``forward`` emits one scalar per time step. ``encode``/``decode`` take the
    output at the LAST step of a window of past samples as the prediction for
    the sample that follows it. ``encode`` stores each sample's residual
    against that prediction, and ``decode`` adds the predictions back.

    Args:
        input_size: Features per time step. ``encode``/``decode`` feed one
            scalar per step, so they require ``input_size == 1``.
        hidden_size: LSTM hidden size. It is also reused as the maximum number
            of past samples fed to the model by ``encode``/``decode``.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        # Suggested device for callers (e.g. ``model.to(model.device)``). The
        # model is NOT moved here, so computations below follow the device the
        # parameters (or the input) actually live on.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, x):
        """Run the LSTM over ``x`` and project every time step to one value.

        Args:
            x: Input tensor laid out as ``(batch, seq_len, input_size)``
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, seq_len, 1)``.
        """
        # Omitting (h0, c0) makes nn.LSTM start from zero states allocated on
        # x's own device and dtype. Allocating them on self.device instead
        # crashed on CUDA machines while the model/input were still on the CPU.
        out, _ = self.rnn(x)
        return self.fc(out)

    def _param_device_dtype(self):
        """Return the (device, dtype) of the model's parameters."""
        param = next(self.parameters())
        return param.device, param.dtype

    def _predict_next(self, history, device, dtype):
        """Predict the sample following ``history``; 0.0 when it is empty."""
        if not history:
            return 0.0
        window = history[max(0, len(history) - self.hidden_size):]
        context = torch.tensor(window, dtype=dtype, device=device).reshape(1, -1, 1)
        # Last time step: the only output that has seen the whole window.
        return self.forward(context)[0, -1, 0].item()

    def encode(self, data):
        """Encode ``data`` as residuals against the model's predictions.

        Args:
            data: 1-D sequence of numbers (anything ``float()`` accepts per
                element).

        Returns:
            List of Python floats, one per input sample. The first sample has
            no context, so its prediction is 0.0 and its residual is the sample
            itself. ``decode`` on this list reconstructs ``data`` up to
            floating-point rounding.
        """
        self.eval()
        device, dtype = self._param_device_dtype()
        encoded_data = []
        # Predict from reconstructed values (exactly what decode will see), not
        # from the raw input. That keeps encoder and decoder contexts
        # bit-identical, so rounding errors cannot accumulate.
        reconstructed = []
        with torch.no_grad():
            for sample in data:
                prediction = self._predict_next(reconstructed, device, dtype)
                delta = float(sample) - prediction
                encoded_data.append(delta)
                reconstructed.append(prediction + delta)
        return encoded_data

    def decode(self, encoded_data):
        """Reconstruct the sequence produced by ``encode``.

        Args:
            encoded_data: Residuals as returned by ``encode``.

        Returns:
            List of Python floats, one per residual.
        """
        self.eval()
        device, dtype = self._param_device_dtype()
        decoded_data = []
        with torch.no_grad():
            for delta in encoded_data:
                prediction = self._predict_next(decoded_data, device, dtype)
                decoded_data.append(prediction + float(delta))
        return decoded_data
|
|
|
|
|
|
|
|
class FixedInputNNPredictor(BaseModel):
    """Next-value predictor: a two-layer MLP over a fixed-length window.

    The network maps the previous ``input_size`` samples to one predicted
    sample. ``encode``/``decode`` use it for predictive delta coding.

    Args:
        input_size: Window length, i.e. number of past samples per prediction.
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        # Suggested device for callers (e.g. ``model.to(model.device)``). The
        # model is NOT moved here, so encode/decode run on whatever device the
        # parameters actually live on.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, x):
        """Map windows ``x`` of shape ``(..., input_size)`` to ``(..., 1)``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

    def _param_device_dtype(self):
        """Return the (device, dtype) of the model's parameters."""
        param = next(self.parameters())
        return param.device, param.dtype

    def _predict_next(self, history, device, dtype):
        """Predict the sample after ``history``; 0.0 until a full window exists."""
        context_size = self.fc1.in_features
        if len(history) < context_size:
            # fc1 only accepts full windows. Until one exists, predict 0.0 so
            # the residual is the raw sample and decode can bootstrap from it.
            return 0.0
        window = history[len(history) - context_size:]
        context = torch.tensor(window, dtype=dtype, device=device).reshape(1, context_size)
        return self.forward(context)[0, 0].item()

    def encode(self, data):
        """Encode ``data`` as residuals against the model's predictions.

        Args:
            data: 1-D sequence of numbers (anything ``float()`` accepts per
                element).

        Returns:
            List of Python floats, one per input sample. The first
            ``input_size`` entries are the raw samples, because there is no
            full window to predict from yet. Each later entry is
            ``sample - prediction``, where the prediction uses the preceding
            ``input_size`` reconstructed samples. ``decode`` on this list
            reconstructs ``data`` up to floating-point rounding.
        """
        self.eval()
        device, dtype = self._param_device_dtype()
        encoded_data = []
        # Predict from reconstructed values (exactly what decode will have),
        # not from the raw input, so encoder and decoder windows stay
        # bit-identical.
        reconstructed = []
        with torch.no_grad():
            for sample in data:
                prediction = self._predict_next(reconstructed, device, dtype)
                delta = float(sample) - prediction
                encoded_data.append(delta)
                reconstructed.append(prediction + delta)
        return encoded_data

    def decode(self, encoded_data):
        """Reconstruct the sequence produced by ``encode``.

        Args:
            encoded_data: Residuals as returned by ``encode``.

        Returns:
            List of Python floats, one per residual.
        """
        self.eval()
        device, dtype = self._param_device_dtype()
        decoded_data = []
        with torch.no_grad():
            for delta in encoded_data:
                prediction = self._predict_next(decoded_data, device, dtype)
                decoded_data.append(prediction + float(delta))
        return decoded_data
|