2024-05-25 17:31:08 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2024-05-26 15:40:00 +02:00
|
|
|
import torch.fft as fft
|
2024-05-27 17:00:02 +02:00
|
|
|
import pywt
|
2024-05-25 17:31:08 +02:00
|
|
|
|
|
|
|
def get_activation(name):
    """Build a fresh activation module from its name.

    Recognized names are 'ReLU', 'Sigmoid', 'Tanh', 'LeakyReLU', 'ELU' and the
    string 'None', which yields ``nn.Identity()``. Every module is created with
    its default constructor arguments.

    Raises:
        KeyError: If ``name`` is not one of the recognized names.
    """
    registry = dict(
        ReLU=nn.ReLU,
        Sigmoid=nn.Sigmoid,
        Tanh=nn.Tanh,
        LeakyReLU=nn.LeakyReLU,
        ELU=nn.ELU,
    )
    # 'None' is a Python keyword, so it cannot be passed as a dict() kwarg.
    registry['None'] = nn.Identity
    activation_cls = registry[name]
    return activation_cls()
|
|
|
|
|
2024-05-27 17:00:02 +02:00
|
|
|
class FeatureExtractor(nn.Module):
    """Turn fixed-length signals into a flat feature vector via configurable transforms.

    Each entry of ``transforms`` is a dict with a ``'type'`` key
    (``'identity'``, ``'fourier'`` or ``'wavelet'``) and an optional
    ``'length'`` key selecting how many trailing samples of the signal the
    transform sees. A missing ``length``, ``None`` or ``-1`` means
    ``input_size``. ``'wavelet'`` entries also need a ``'wavelet_type'`` key,
    which is passed to ``pywt.wavedec``.

    Args:
        input_size: Default number of trailing samples used by a transform.
        transforms: List of transform config dicts as described above.

    Raises:
        ValueError: If a transform type is unknown or a length is < 1.
    """

    def __init__(self, input_size, transforms):
        super(FeatureExtractor, self).__init__()
        self.input_size = input_size
        self.transforms = self.build_transforms(transforms)

    def build_transforms(self, config):
        """Normalize the transform config into a list of tuples.

        Returns:
            A list of ``('identity', length)``, ``('fourier', length)`` or
            ``('wavelet', wavelet_type, length)`` tuples, in config order.

        Raises:
            ValueError: On an unknown ``'type'`` (previously such entries were
                silently dropped) or a ``length`` below 1.
        """
        transforms = []
        for item in config:
            transform_type = item['type']
            if transform_type not in ('identity', 'fourier', 'wavelet'):
                raise ValueError(
                    f"Unknown transform type {transform_type!r}; "
                    "expected 'identity', 'fourier' or 'wavelet'"
                )
            length = item.get('length', self.input_size)
            if length in [None, -1]:
                length = self.input_size
            # length 0 would make the ``x[:, -length:]`` tail slice return the
            # whole signal, and other negatives would make
            # compute_output_size() disagree with forward().
            if length < 1:
                raise ValueError(f"Transform length must be >= 1, got {length!r}")
            if transform_type == 'wavelet':
                transforms.append(('wavelet', item['wavelet_type'], length))
            else:
                transforms.append((transform_type, length))
        return transforms

    @staticmethod
    def _tail(x, length):
        """Return the last ``length`` columns of the 2-D tensor ``x``."""
        # A slice longer than the signal would silently return fewer
        # features than compute_output_size() reports.
        if length > x.size(1):
            raise ValueError(
                f"Transform length {length} exceeds signal length {x.size(1)}"
            )
        return x[:, -length:]

    def forward(self, x):
        """Apply every configured transform and concatenate the results.

        Args:
            x: Tensor of shape ``(batch_1, batch_2, timesteps)``.

        Returns:
            Tensor of shape ``(batch_1, batch_2, self.compute_output_size())``.
            The Fourier transform contributes its real parts followed by its
            imaginary parts. Wavelet features are computed through NumPy and
            therefore carry no gradient back to ``x``.
        """
        batch_1, batch_2, timesteps = x.size()
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat = x.reshape(batch_1 * batch_2, timesteps)

        if not self.transforms:
            # Consistent with compute_output_size() == 0 instead of crashing
            # in torch.cat on an empty list.
            return flat.new_zeros(batch_1, batch_2, 0)

        outputs = []
        for transform in self.transforms:
            kind = transform[0]
            if kind == 'identity':
                _, length = transform
                outputs.append(self._tail(flat, length))
            elif kind == 'fourier':
                _, length = transform
                spectrum = fft.fft(self._tail(flat, length), dim=1)
                outputs.append(spectrum.real)
                outputs.append(spectrum.imag)
            elif kind == 'wavelet':
                _, wavelet_type, length = transform
                # detach() so .numpy() does not raise when x requires grad.
                signal = self._tail(flat, length).detach().cpu().numpy()
                coeffs = pywt.wavedec(signal, wavelet_type)
                outputs.append(torch.cat(
                    [torch.tensor(c, dtype=torch.float32, device=x.device) for c in coeffs],
                    dim=1,
                ))

        concatenated_outputs = torch.cat(outputs, dim=1)
        # Restore the two leading batch dimensions.
        return concatenated_outputs.view(batch_1, batch_2, -1)

    def compute_output_size(self):
        """Return the feature count ``forward`` produces per (batch_1, batch_2) entry."""
        size = 0
        for transform in self.transforms:
            if transform[0] == 'identity':
                _, length = transform
                size += length
            elif transform[0] == 'fourier':
                _, length = transform
                size += length * 2  # real and imaginary parts
            elif transform[0] == 'wavelet':
                _, wavelet_type, length = transform
                # Coefficient lengths depend on the wavelet and decomposition
                # level, so measure them on a dummy signal of the same length.
                test_signal = torch.zeros(length)
                coeffs = pywt.wavedec(test_signal.numpy(), wavelet_type)
                size += sum(len(c) for c in coeffs)
        return size
|
|
|
|
|
2024-05-26 15:40:00 +02:00
|
|
|
class LatentFCProjector(nn.Module):
    """Project a feature vector to a latent vector through a stack of linear layers.

    Each width in ``layer_shapes`` becomes a hidden ``nn.Linear``. After the
    i-th one, the activation named ``activations[i]`` is inserted unless that
    name is the string ``'None'``. A final ``nn.Linear`` maps to
    ``latent_size``.

    Args:
        feature_size: Input width.
        latent_size: Output width (also stored as ``self.latent_size``).
        layer_shapes: Hidden layer widths, in order.
        activations: Activation names indexed in step with ``layer_shapes``;
            must provide at least ``len(layer_shapes)`` entries.
    """

    def __init__(self, feature_size, latent_size, layer_shapes, activations):
        super().__init__()
        stack = []
        width = feature_size
        for idx, hidden_width in enumerate(layer_shapes):
            stack.append(nn.Linear(width, hidden_width))
            act_name = activations[idx]
            # 'None' adds no module at all, so it takes no slot in the Sequential.
            if act_name != 'None':
                stack.append(get_activation(act_name))
            width = hidden_width
        stack.append(nn.Linear(width, latent_size))
        self.fc = nn.Sequential(*stack)
        self.latent_size = latent_size

    def forward(self, x):
        """Run ``x`` (last dimension ``feature_size``) through the layer stack."""
        return self.fc(x)
|
|
|
|
|
|
|
|
class LatentRNNProjector(nn.Module):
    """Project per-sample feature vectors to latent vectors with an LSTM and a linear head.

    Args:
        feature_size: Width of each input feature vector (LSTM ``input_size``).
        rnn_hidden_size: LSTM hidden size.
        rnn_num_layers: Number of stacked LSTM layers.
        latent_size: Output width (also stored as ``self.latent_size``).
    """

    def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size):
        super(LatentRNNProjector, self).__init__()
        self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, latent_size)
        self.latent_size = latent_size

    def forward(self, x):
        """Map ``x`` of shape ``(batch_1, batch_2, feature_size)`` to
        ``(batch_1, batch_2, latent_size)``.

        Each of the ``batch_1 * batch_2`` feature vectors is fed to the LSTM as
        its own length-1 sequence, so samples are processed independently.
        Note: ``nn.LSTM`` interprets a 2-D input as a single *unbatched*
        sequence. Flattening to 2-D would therefore run the recurrence across
        samples, leaking state from each sample into all later ones.

        NOTE(review): if ``batch_2`` is really meant to be a time axis, pass
        ``x`` to the LSTM directly instead. Confirm the intended semantics.
        """
        batch_1, batch_2, feature_size = x.size()
        # (N, 1, feature_size): with batch_first=True these are N independent
        # sequences of length 1. reshape also accepts non-contiguous x.
        seq = x.reshape(batch_1 * batch_2, 1, feature_size)
        out, _ = self.rnn(seq)
        # out is (N, 1, hidden); take the last (only) step before the head.
        latent = self.fc(out[:, -1, :])
        return latent.view(batch_1, batch_2, self.latent_size)
|
|
|
|
|
|
|
|
class MiddleOut(nn.Module):
    """Combine a node's latent with its peers' latents into a region latent.

    For every peer ``p``, ``[my_latent, peer_latent_p, metric_p]`` goes through
    one shared linear layer, and the results are averaged over peers.

    Args:
        latent_size: Width of ``my_latent`` and of each peer latent.
        region_latent_size: Width of the produced latent.
        num_peers: Configured peer count. ``0`` disables mixing, and
            ``forward`` then returns ``my_latent`` unchanged.
        residual: If true, each peer contribution is scaled by its metric and
            ``forward`` returns ``my_latent - mean(contributions)``.

    Raises:
        ValueError: If ``residual`` is set or ``num_peers == 0`` while
            ``latent_size != region_latent_size``. In those modes the output is
            derived from ``my_latent`` and must have the same width.
    """

    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
        super(MiddleOut, self).__init__()
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if (residual or num_peers == 0) and latent_size != region_latent_size:
            raise ValueError(
                "latent_size must equal region_latent_size when residual=True "
                f"or num_peers == 0 (got {latent_size} and {region_latent_size})"
            )
        self.num_peers = num_peers
        # Input: my latent + one peer latent + that peer's scalar metric.
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
        self.residual = residual

    def forward(self, my_latent, peer_latents, peer_metrics):
        """Mix ``my_latent`` with every peer latent and average the results.

        Args:
            my_latent: Tensor of shape ``(batch, latent_size)``.
            peer_latents: Tensor of shape ``(batch, peers, latent_size)``.
            peer_metrics: Tensor of shape ``(batch, peers)``, one scalar per peer.

        Returns:
            Tensor of shape ``(batch, region_latent_size)``. When
            ``num_peers == 0``, ``my_latent`` itself is returned.

        Raises:
            ValueError: If ``num_peers > 0`` but ``peer_latents`` has no peers.
        """
        if self.num_peers == 0:
            return my_latent

        # The loop count comes from the data, not from self.num_peers.
        present = peer_latents.shape[-2]
        if present == 0:
            raise ValueError("MiddleOut.forward received no peer latents")

        # Pair my_latent with every peer so all peers go through fc in one call.
        mine = my_latent.unsqueeze(1).expand(-1, present, -1)
        metrics = peer_metrics.unsqueeze(-1)  # (batch, peers, 1)
        combined_input = torch.cat((mine, peer_latents, metrics), dim=-1)
        contributions = self.fc(combined_input)  # (batch, peers, region)
        if self.residual:
            contributions = contributions * metrics

        averaged_latent = contributions.mean(dim=1)
        if self.residual:
            return my_latent - averaged_latent
        return averaged_latent
|
2024-05-25 17:31:08 +02:00
|
|
|
|
|
|
|
class Predictor(nn.Module):
    """Map a region latent vector to a single scalar through a linear stack.

    Consecutive widths ``region_latent_size -> layer_shapes[0] -> ...`` become
    ``nn.Linear`` layers. After the i-th one, ``activations[i]`` is inserted
    unless it is the string ``'None'``. A final ``nn.Linear`` outputs width 1.

    Args:
        region_latent_size: Input width.
        layer_shapes: Hidden layer widths, in order.
        activations: Activation names indexed in step with ``layer_shapes``;
            must provide at least ``len(layer_shapes)`` entries.
    """

    def __init__(self, region_latent_size, layer_shapes, activations):
        super().__init__()
        widths = [region_latent_size, *layer_shapes]
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            # 'None' means no activation module at all (not an nn.Identity).
            if activations[idx] != 'None':
                modules.append(get_activation(activations[idx]))
        modules.append(nn.Linear(widths[-1], 1))
        self.fc = nn.Sequential(*modules)

    def forward(self, latent):
        """Return the prediction for ``latent``; the last output dimension is 1."""
        return self.fc(latent)
|