From 6076aaf36c6e155e90a5e80f4a20891e5bb08218 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Sun, 26 May 2024 15:40:00 +0200
Subject: [PATCH] Implement Fourier Latent Projector

---
 config.yaml |  8 +++++---
 main.py     |  6 ++++--
 models.py   | 41 +++++++++++++++++++++++++++++++++++++++--
 3 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/config.yaml b/config.yaml
index 06cfdfc..44e7c3e 100644
--- a/config.yaml
+++ b/config.yaml
@@ -112,13 +112,15 @@ name: RNN
 import: $

 latent_projector:
-  type: rnn  # Options: 'fc', 'rnn'
+  type: rnn  # Options: 'fc', 'rnn', 'fourier'
   input_size: 1953  # =0.1s (19531 = 1s). Input size for the Latent Projector (length of snippets).
   latent_size: 4  # Size of the latent representation before message passing.
-  #layer_shapes: [32, 8]  # List of layer sizes for the latent projector (if type is 'fc').
-  #activations: ['ReLU', 'ReLU']  # Activation functions for the latent projector layers (if type is 'fc').
+  #layer_shapes: [32, 8]  # List of layer sizes for the latent projector (if type is 'fc' or 'fourier').
+  #activations: ['ReLU', 'ReLU']  # Activation functions for the latent projector layers (if type is 'fc' or 'fourier').
   rnn_hidden_size: 3  # Hidden size for the RNN projector (if type is 'rnn').
   rnn_num_layers: 2  # Number of layers for the RNN projector (if type is 'rnn').
+  #num_frequencies: 16  # Number of frequency bins for the Fourier decomposition (if type is 'fourier').
+  #pass_raw_len: null  # How many of the most recent raw samples to pass to the net in addition to the frequency features (null = all) (if type is 'fourier').

 middle_out:
   region_latent_size: 4  # Size of the latent representation after message passing.
diff --git a/main.py b/main.py
index ee7a1a6..04e0202 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@ import numpy as np
 import random, math
 from utils import visualize_prediction, plot_delta_distribution
 from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
-from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
+from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
 from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
 import wandb
 from pycallgraph2 import PyCallGraph
@@ -51,9 +51,11 @@ class SpikeRunner(Slate_Runner):
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')

         if latent_projector_type == 'fc':
-            self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         elif latent_projector_type == 'rnn':
             self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+        elif latent_projector_type == 'fourier':
+            self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)

         self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
         self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
diff --git a/models.py b/models.py
index 91926c0..7f5e98e 100644
--- a/models.py
+++ b/models.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.fft as fft

 def get_activation(name):
     activations = {
@@ -12,9 +13,9 @@ def get_activation(name):
     }
     return activations[name]()

-class LatentProjector(nn.Module):
+class LatentFCProjector(nn.Module):
     def __init__(self, input_size, latent_size, layer_shapes, activations):
-        super(LatentProjector, self).__init__()
+        super(LatentFCProjector, self).__init__()
         layers = []
         in_features = input_size
         for i, out_features in enumerate(layer_shapes):
@@ -42,6 +43,42 @@ class LatentRNNProjector(nn.Module):
         latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
         return latent

+class FourierTransformLayer(nn.Module):
+    def forward(self, x):
+        x_fft = fft.rfft(x, dim=-1)
+        return x_fft
+
+class LatentFourierProjector(nn.Module):
+    def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
+        super(LatentFourierProjector, self).__init__()
+        self.fourier_transform = FourierTransformLayer()
+        layers = []
+        if pass_raw_len is None:
+            pass_raw_len = input_size
+        else:
+            assert pass_raw_len <= input_size
+        in_features = pass_raw_len + (input_size // 2 + 1) * 2  # last pass_raw_len raw samples + (input_size // 2 + 1) rfft bins, each with a real and an imaginary part
+        for i, out_features in enumerate(layer_shapes):
+            layers.append(nn.Linear(in_features, out_features))
+            if activations[i] != 'None':
+                layers.append(get_activation(activations[i]))
+            in_features = out_features
+        layers.append(nn.Linear(in_features, latent_size))
+        self.fc = nn.Sequential(*layers)
+        self.latent_size = latent_size
+        self.pass_raw_len = pass_raw_len
+
+    def forward(self, x):
+        # Apply Fourier Transform
+        x_fft = self.fourier_transform(x)
+        # Separate real and imaginary parts and concatenate them
+        x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
+        # Combine the last pass_raw_len raw samples with the Fourier features
+        combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
+        # Process through fully connected layers
+        latent = self.fc(combined_input)
+        return latent
+
 class MiddleOut(nn.Module):
     def __init__(self, latent_size, region_latent_size, num_peers):
         super(MiddleOut, self).__init__()
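
A minimal usage sketch of the new LatentFourierProjector follows; it is not part of the patch. The concrete values for input_size, latent_size, layer_shapes, activations, and pass_raw_len are illustrative assumptions that mirror the (commented-out) keys in config.yaml, and the input is assumed to be a 2D (batch, samples) tensor, which is what the indexing x[:, -self.pass_raw_len:] in forward expects.

    import torch
    from models import LatentFourierProjector

    # Illustrative hyperparameters (assumptions, not values taken from the patch).
    projector = LatentFourierProjector(
        input_size=1953,               # length of each raw snippet
        latent_size=4,                 # size of the latent representation
        layer_shapes=[32, 8],          # hidden layer sizes of the FC head
        activations=['ReLU', 'ReLU'],  # one activation name per hidden layer
        pass_raw_len=64,               # pass only the last 64 raw samples alongside the rfft features
    )

    x = torch.randn(16, 1953)          # (batch, samples)
    latent = projector(x)              # rfft -> concat real/imag with last 64 raw samples -> FC head
    print(latent.shape)                # expected: torch.Size([16, 4])

Concatenating the last pass_raw_len raw samples with the real and imaginary rfft coefficients gives the FC head both a time-domain and a frequency-domain view of the snippet; with pass_raw_len left at null, the entire raw snippet is passed alongside the full spectrum.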