Implement Fourier Latent Projector

parent 7808ba9464
commit 6076aaf36c
@@ -112,13 +112,15 @@ name: RNN
 import: $
 
 latent_projector:
-  type: rnn # Options: 'fc', 'rnn'
+  type: rnn # Options: 'fc', 'rnn', 'fourier'
   input_size: 1953 # =0.1s (19531 would be =1s). Input size for the Latent Projector (length of snippets).
   latent_size: 4 # Size of the latent representation before message passing.
-  #layer_shapes: [32, 8] # List of layer sizes for the latent projector (if type is 'fc').
-  #activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc').
+  #layer_shapes: [32, 8] # List of layer sizes for the latent projector (if type is 'fc' or 'fourier').
+  #activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc' or 'fourier').
   rnn_hidden_size: 3 # Hidden size for the RNN projector (if type is 'rnn').
   rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
+  #num_frequencies: 16 # Number of frequency bins for the Fourier decomposition (if type is 'fourier').
+  #pass_raw_len: null # How many of the last samples to pass raw to the net in addition to the frequencies (null = all) (if type is 'fourier').
 
 middle_out:
   region_latent_size: 4 # Size of the latent representation after message passing.
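
For reference, the first FC layer width implied by these options for type 'fourier' works out as follows (a sketch, not part of the commit; it mirrors the in_features formula introduced in models.py below):

# Sketch, not part of the commit: FC input width for type 'fourier'.
input_size = 1953                  # snippet length from the config above (0.1 s)
num_bins = input_size // 2 + 1     # rfft of a real signal -> 977 frequency bins
fourier_feats = num_bins * 2       # real + imaginary parts -> 1954 features
pass_raw_len = input_size          # pass_raw_len: null means "pass all raw samples"
in_features = pass_raw_len + fourier_feats  # 1953 + 1954 = 3907 inputs to the first FC layer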

main.py (6 lines changed)
@@ -5,7 +5,7 @@ import numpy as np
 import random, math
 from utils import visualize_prediction, plot_delta_distribution
 from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
-from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
+from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
 from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
 import wandb
 from pycallgraph2 import PyCallGraph
@@ -51,9 +51,11 @@ class SpikeRunner(Slate_Runner):
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
 
         if latent_projector_type == 'fc':
-            self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         elif latent_projector_type == 'rnn':
             self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+        elif latent_projector_type == 'fourier':
+            self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
 
         self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
         self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
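
All three projectors accept the same constructor interface (latent_size, input_size, plus the expanded latent_projector config), so the elif chain could equivalently be table-driven; a minimal sketch, assuming the signatures stay aligned (not part of the commit):

# Sketch, not part of the commit: table-driven equivalent of the dispatch above.
PROJECTOR_CLASSES = {
    'fc': LatentFCProjector,
    'rnn': LatentRNNProjector,
    'fourier': LatentFourierProjector,
}
projector_cls = PROJECTOR_CLASSES[latent_projector_type]
self.projector = projector_cls(latent_size=latent_size, input_size=input_size,
                               **slate.consume(config, 'latent_projector', expand=True)).to(device)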

models.py (41 lines changed)
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.fft as fft
 
 def get_activation(name):
     activations = {
@@ -12,9 +13,9 @@ def get_activation(name):
     }
     return activations[name]()
 
-class LatentProjector(nn.Module):
+class LatentFCProjector(nn.Module):
     def __init__(self, input_size, latent_size, layer_shapes, activations):
-        super(LatentProjector, self).__init__()
+        super(LatentFCProjector, self).__init__()
         layers = []
         in_features = input_size
         for i, out_features in enumerate(layer_shapes):
@@ -42,6 +43,42 @@ class LatentRNNProjector(nn.Module):
         latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
         return latent
 
+class FourierTransformLayer(nn.Module):
+    def forward(self, x):
+        x_fft = fft.rfft(x, dim=-1)  # real-input FFT over the last (time) dim -> input_size // 2 + 1 complex bins
+        return x_fft
+
+class LatentFourierProjector(nn.Module):
+    def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
+        super(LatentFourierProjector, self).__init__()
+        self.fourier_transform = FourierTransformLayer()
+        layers = []
+        if pass_raw_len is None:
+            pass_raw_len = input_size
+        else:
+            assert pass_raw_len <= input_size
+        in_features = pass_raw_len + (input_size // 2 + 1) * 2  # raw samples + (input_size // 2 + 1) real and imaginary parts
+        for i, out_features in enumerate(layer_shapes):
+            layers.append(nn.Linear(in_features, out_features))
+            if activations[i] != 'None':
+                layers.append(get_activation(activations[i]))
+            in_features = out_features
+        layers.append(nn.Linear(in_features, latent_size))
+        self.fc = nn.Sequential(*layers)
+        self.latent_size = latent_size
+        self.pass_raw_len = pass_raw_len
+
+    def forward(self, x):
+        # Apply the Fourier transform
+        x_fft = self.fourier_transform(x)
+        # Split the complex spectrum into real and imaginary parts and concatenate them
+        x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
+        # Combine the last pass_raw_len raw samples with the Fourier features
+        combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
+        # Process through the fully connected layers
+        latent = self.fc(combined_input)
+        return latent
+
 class MiddleOut(nn.Module):
     def __init__(self, latent_size, region_latent_size, num_peers):
         super(MiddleOut, self).__init__()
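
A quick shape check for the new projector (a sketch, not part of the commit; the 2018 below is 64 raw samples plus 977 rfft bins times 2):

# Sketch, not part of the commit: smoke test for LatentFourierProjector.
import torch
from models import LatentFourierProjector

proj = LatentFourierProjector(input_size=1953, latent_size=4,
                              layer_shapes=[32, 8], activations=['ReLU', 'ReLU'],
                              pass_raw_len=64)
x = torch.randn(10, 1953)   # batch of 10 snippets
z = proj(x)                 # first FC layer sees 64 + (1953 // 2 + 1) * 2 = 2018 features
assert z.shape == (10, 4)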