Spikey/models.py

119 lines
4.7 KiB
Python

import torch
import torch.nn as nn
import torch.fft as fft
def get_activation(name):
activations = {
'ReLU': nn.ReLU,
'Sigmoid': nn.Sigmoid,
'Tanh': nn.Tanh,
'LeakyReLU': nn.LeakyReLU,
'ELU': nn.ELU,
'None': nn.Identity
}
return activations[name]()
class LatentFCProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations):
super(LatentFCProjector, self).__init__()
layers = []
in_features = input_size
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
def forward(self, x):
return self.fc(x)
class LatentRNNProjector(nn.Module):
def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
super(LatentRNNProjector, self).__init__()
self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
self.fc = nn.Linear(rnn_hidden_size, latent_size)
self.latent_size = latent_size
def forward(self, x):
batch_1, batch_2, timesteps = x.size()
out, _ = self.rnn(x.view(batch_1 * batch_2, timesteps))
latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
return latent
class FourierTransformLayer(nn.Module):
def forward(self, x):
x_fft = fft.rfft(x, dim=-1)
return x_fft
class LatentFourierProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
super(LatentFourierProjector, self).__init__()
self.fourier_transform = FourierTransformLayer()
layers = []
if pass_raw_len is None:
pass_raw_len = input_size
else:
assert pass_raw_len <= input_size
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
self.pass_raw_len = pass_raw_len
def forward(self, x):
# Apply Fourier Transform
x_fft = self.fourier_transform(x)
# Separate real and imaginary parts and combine them
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
# Combine part of the raw input with Fourier features
combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
# Process through fully connected layers
latent = self.fc(combined_input)
return latent
class MiddleOut(nn.Module):
def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
super(MiddleOut, self).__init__()
if residual:
assert latent_size == region_latent_size
self.num_peers = num_peers
self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
self.residual = residual
def forward(self, my_latent, peer_latents, peer_metrics):
new_latents = []
for p in range(peer_latents.shape[-2]):
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1)
new_latent = self.fc(combined_input)
new_latents.append(new_latent * metric.unsqueeze(1))
new_latents = torch.stack(new_latents)
averaged_latent = torch.mean(new_latents, dim=0)
if self.residual:
return my_latent - averaged_latent
return averaged_latent
class Predictor(nn.Module):
def __init__(self, region_latent_size, layer_shapes, activations):
super(Predictor, self).__init__()
layers = []
in_features = region_latent_size
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, 1))
self.fc = nn.Sequential(*layers)
def forward(self, latent):
return self.fc(latent)