Implement Fourier Latent Projector

This commit is contained in:
Dominik Moritz Roth 2024-05-26 15:40:00 +02:00
parent 7808ba9464
commit 6076aaf36c
3 changed files with 48 additions and 7 deletions

View File

@ -112,13 +112,15 @@ name: RNN
import: $ import: $
latent_projector: latent_projector:
type: rnn # Options: 'fc', 'rnn' type: rnn # Options: 'fc', 'rnn', 'fourier'
input_size: 1953 # =0.1s 19531 # =1s Input size for the Latent Projector (length of snippets). input_size: 1953 # =0.1s 19531 # =1s Input size for the Latent Projector (length of snippets).
latent_size: 4 # Size of the latent representation before message passing. latent_size: 4 # Size of the latent representation before message passing.
#layer_shapes: [32, 8] # List of layer sizes for the latent projector (if type is 'fc'). #layer_shapes: [32, 8] # List of layer sizes for the latent projector (if type is 'fc' or 'fourier').
#activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc'). #activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc' or 'fourier').
rnn_hidden_size: 3 # Hidden size for the RNN projector (if type is 'rnn'). rnn_hidden_size: 3 # Hidden size for the RNN projector (if type is 'rnn').
rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn'). rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
#num_frequencies: 16 # Number of frquency bins for the fourier decomp (if type is 'fourier').
#pass_raw_len: null # How many last samples to give raw to the net in addition to freqs (null = all) (if type is 'fourier').
middle_out: middle_out:
region_latent_size: 4 # Size of the latent representation after message passing. region_latent_size: 4 # Size of the latent representation after message passing.

View File

@ -5,7 +5,7 @@ import numpy as np
import random, math import random, math
from utils import visualize_prediction, plot_delta_distribution from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
import wandb import wandb
from pycallgraph2 import PyCallGraph from pycallgraph2 import PyCallGraph
@ -51,9 +51,11 @@ class SpikeRunner(Slate_Runner):
region_latent_size = slate.consume(config, 'middle_out.region_latent_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
if latent_projector_type == 'fc': if latent_projector_type == 'fc':
self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn': elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'fourier':
self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device) self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device) self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)

View File

@ -1,5 +1,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.fft as fft
def get_activation(name): def get_activation(name):
activations = { activations = {
@ -12,9 +13,9 @@ def get_activation(name):
} }
return activations[name]() return activations[name]()
class LatentProjector(nn.Module): class LatentFCProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations): def __init__(self, input_size, latent_size, layer_shapes, activations):
super(LatentProjector, self).__init__() super(LatentFCProjector, self).__init__()
layers = [] layers = []
in_features = input_size in_features = input_size
for i, out_features in enumerate(layer_shapes): for i, out_features in enumerate(layer_shapes):
@ -42,6 +43,42 @@ class LatentRNNProjector(nn.Module):
latent = self.fc(out).view(batch_1, batch_2, self.latent_size) latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
return latent return latent
class FourierTransformLayer(nn.Module):
def forward(self, x):
x_fft = fft.rfft(x, dim=-1)
return x_fft
class LatentFourierProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
super(LatentFourierProjector, self).__init__()
self.fourier_transform = FourierTransformLayer()
layers = []
if pass_raw_len is None:
pass_raw_len = input_size
else:
assert pass_raw_len <= input_size
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
self.pass_raw_len = pass_raw_len
def forward(self, x):
# Apply Fourier Transform
x_fft = self.fourier_transform(x)
# Separate real and imaginary parts and combine them
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
# Combine part of the raw input with Fourier features
combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
# Process through fully connected layers
latent = self.fc(combined_input)
return latent
class MiddleOut(nn.Module): class MiddleOut(nn.Module):
def __init__(self, latent_size, region_latent_size, num_peers): def __init__(self, latent_size, region_latent_size, num_peers):
super(MiddleOut, self).__init__() super(MiddleOut, self).__init__()