Feature Extraction

This commit is contained in:
Dominik Moritz Roth 2024-05-27 17:00:02 +02:00
parent d9d0e2b8c8
commit eefacf884d
3 changed files with 126 additions and 58 deletions

View File

@ -1,5 +1,38 @@
name: EXAMPLE
feature_extractor:
- type: 'identity' # Pass the last n samples of the input data directly.
length: 8 # Number of last samples to pass directly. Use full input size if set to null.
- type: 'fourier' # Apply Fourier transform to the input data.
length: null # Use full input size if set to null. Fourier transform outputs both real and imaginary parts, doubling the size. (Computationally expensive)
- type: 'wavelet' # Apply selected wavelet transform to the input data.
wavelet_type: 'haar' # Haar wavelet is simple and fast, but may not capture detailed features well.
length: null # Use full input size if set to null.
- type: 'wavelet'
wavelet_type: 'cgau1' # Complex Gaussian wavelets are used for complex-valued signal analysis and capturing phase information.
length: null # Use full input size if set to null.
- type: 'wavelet'
wavelet_type: 'db1' # Daubechies wavelets provide a balance between time and frequency localization.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'sym2' # Symlet wavelets are nearly symmetrical, offering improved phase characteristics over Daubechies.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'coif1' # Coiflet wavelets have more vanishing moments, suitable for capturing polynomial trends.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'bior1.3' # Biorthogonal wavelets provide perfect reconstruction and linear phase characteristics.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'rbio1.3' # Reverse Biorthogonal wavelets are similar to Biorthogonal but optimized for different applications.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'dmey' # Discrete Meyer wavelets offer good frequency localization, ideal for signals with oscillatory components.
length: null # Use full input size if set to null. (Computationally expensive)
- type: 'wavelet'
wavelet_type: 'morl' # Morlet wavelets are useful for time-frequency analysis due to their Gaussian-modulated sinusoid shape.
length: null # Use full input size if set to null. (Computationally expensive)
latent_projector:
type: 'fc' # Type of latent projector: 'fc', 'rnn', 'fourier'
input_size: 1953 # Input size for the Latent Projector (length of snippets). (=0.1s)

39
main.py
View File

@ -4,8 +4,8 @@ import torch.nn as nn
import numpy as np
import random, math
from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
import wandb
from pycallgraph2 import PyCallGraph
@ -26,10 +26,10 @@ class SpikeRunner(Slate_Runner):
data_url = slate.consume(data_config, 'url')
cut_length = slate.consume(data_config, 'cut_length', None)
download_and_extract_data(data_url)
all_data = load_all_wavs('data', cut_length)
self.all_data = load_all_wavs('data', cut_length)
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)
print("Reconstructing thread topology")
self.topology_matrix = compute_topology_metrics(self.train_data)
@ -52,12 +52,13 @@ class SpikeRunner(Slate_Runner):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
feature_size = self.feat.compute_output_size()
if latent_projector_type == 'fc':
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'fourier':
self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
else:
raise Exception('No such Latent Projector')
@ -284,15 +285,8 @@ class SpikeRunner(Slate_Runner):
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
if self.full_compression:
self.encoder.build_model(my_latent.cpu().numpy())
compressed_data = self.encoder.encode(my_latent.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
compression_ratio = len(my_latent) / len(compressed_data)
compression_ratios.append(compression_ratio)
if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
exact_matches += 1
total_sequences += 1
raw = self.all_data
comp = self.compress(raw)
avg_loss = total_loss / len(self.test_data)
tot_err = sum(errs) / len(errs)
@ -347,6 +341,17 @@ class SpikeRunner(Slate_Runner):
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
print(f"New high score! Models saved at epoch {epoch+1}.")
def compress(raw):
threads = unfuckify_all(raw)
for thread in threads:
# 1. featExtr
# 2. latentProj
# 3. middleOut
# 4. predictor
# 5. calc delta
# 6. encode
# 7. return
if __name__ == '__main__':
print('Initializing...')
slate = Slate({'spikey': SpikeRunner})

112
models.py
View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.fft as fft
import pywt
def get_activation(name):
activations = {
@ -13,11 +14,77 @@ def get_activation(name):
}
return activations[name]()
class FeatureExtractor(nn.Module):
def __init__(self, input_size, transforms):
super(FeatureExtractor, self).__init__()
self.input_size = input_size
self.transforms = self.build_transforms(transforms)
def build_transforms(self, config):
transforms = []
for item in config:
transform_type = item['type']
length = item.get('length', self.input_size)
if length in [None, -1]:
length = self.input_size
if transform_type == 'identity':
transforms.append(('identity', length))
elif transform_type == 'fourier':
transforms.append(('fourier', length))
elif transform_type == 'wavelet':
wavelet_type = item['wavelet_type']
transforms.append(('wavelet', wavelet_type, length))
return transforms
def forward(self, x):
batch_1, batch_2, timesteps = x.size()
x = x.view(batch_1 * batch_2, timesteps) # Combine batch dimensions for processing
outputs = []
for transform in self.transforms:
if transform[0] == 'identity':
_, length = transform
outputs.append(x[:, -length:])
elif transform[0] == 'fourier':
_, length = transform
fourier_transform = fft.fft(x[:, -length:], dim=1)
fourier_real = fourier_transform.real
fourier_imag = fourier_transform.imag
outputs.append(fourier_real)
outputs.append(fourier_imag)
elif transform[0] == 'wavelet':
_, wavelet_type, length = transform
coeffs = pywt.wavedec(x[:, -length:].cpu().numpy(), wavelet_type)
wavelet_coeffs = [torch.tensor(coeff, dtype=torch.float32, device=x.device) for coeff in coeffs]
wavelet_coeffs = torch.cat(wavelet_coeffs, dim=1)
outputs.append(wavelet_coeffs)
concatenated_outputs = torch.cat(outputs, dim=1)
concatenated_outputs = concatenated_outputs.view(batch_1, batch_2, -1) # Reshape back to original batch dimensions
return concatenated_outputs
def compute_output_size(self):
size = 0
for transform in self.transforms:
if transform[0] == 'identity':
_, length = transform
size += length
elif transform[0] == 'fourier':
_, length = transform
size += length * 2 # Fourier transform outputs both real and imaginary parts
elif transform[0] == 'wavelet':
_, wavelet_type, length = transform
# Find the true size of the wavelet coefficients
test_signal = torch.zeros(length)
coeffs = pywt.wavedec(test_signal.numpy(), wavelet_type)
wavelet_size = sum(len(c) for c in coeffs)
size += wavelet_size
return size
class LatentFCProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations):
def __init__(self, feature_size, latent_size, layer_shapes, activations):
super(LatentFCProjector, self).__init__()
layers = []
in_features = input_size
in_features = feature_size
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
@ -31,9 +98,9 @@ class LatentFCProjector(nn.Module):
return self.fc(x)
class LatentRNNProjector(nn.Module):
def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size):
super(LatentRNNProjector, self).__init__()
self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
self.fc = nn.Linear(rnn_hidden_size, latent_size)
self.latent_size = latent_size
@ -43,43 +110,6 @@ class LatentRNNProjector(nn.Module):
latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
return latent
class FourierTransformLayer(nn.Module):
def forward(self, x):
x_fft = fft.rfft(x, dim=-1)
return x_fft
class LatentFourierProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
super(LatentFourierProjector, self).__init__()
self.fourier_transform = FourierTransformLayer()
layers = []
if pass_raw_len is None:
pass_raw_len = input_size
else:
assert pass_raw_len <= input_size
in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
self.pass_raw_len = pass_raw_len
def forward(self, x):
batch_1, batch_2, timesteps = x.size()
x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
latent = self.fc(combined_input)
latent = latent.view(batch_1, batch_2, self.latent_size)
return latent
class MiddleOut(nn.Module):
def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
super(MiddleOut, self).__init__()