Feature Extraction

Dominik Moritz Roth 2024-05-27 17:00:02 +02:00
parent d9d0e2b8c8
commit eefacf884d
3 changed files with 126 additions and 58 deletions


@@ -1,5 +1,38 @@
name: EXAMPLE
+feature_extractor:
+  - type: 'identity'        # Pass the last n samples of the input data through directly.
+    length: 8               # Number of trailing samples to pass through. Use full input size if set to null.
+  - type: 'fourier'         # Apply a Fourier transform to the input data.
+    length: null            # Use full input size if set to null. Outputs both real and imaginary parts, doubling the size. (Computationally expensive)
+  - type: 'wavelet'         # Apply the selected wavelet transform to the input data.
+    wavelet_type: 'haar'    # Haar wavelets are simple and fast, but may not capture detailed features well.
+    length: null            # Use full input size if set to null.
+  - type: 'wavelet'
+    wavelet_type: 'cgau1'   # Complex Gaussian wavelets suit complex-valued signal analysis and capture phase information.
+    length: null            # Use full input size if set to null.
+  - type: 'wavelet'
+    wavelet_type: 'db1'     # Daubechies wavelets balance time and frequency localization.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'sym2'    # Symlets are nearly symmetrical, offering better phase characteristics than Daubechies.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'coif1'   # Coiflets have more vanishing moments, suitable for capturing polynomial trends.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'bior1.3' # Biorthogonal wavelets provide perfect reconstruction and linear phase characteristics.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'rbio1.3' # Reverse biorthogonal wavelets are similar to biorthogonal but optimized for different applications.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'dmey'    # Discrete Meyer wavelets offer good frequency localization, ideal for oscillatory signals.
+    length: null            # Use full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'morl'    # Morlet wavelets suit time-frequency analysis: a Gaussian-modulated sinusoid.
+    length: null            # Use full input size if set to null. (Computationally expensive)
latent_projector:
  type: 'fc'         # Type of latent projector: 'fc', 'rnn', 'fourier'
  input_size: 1953   # Input size for the Latent Projector (length of snippets). (=0.1s)
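Aside: of the wavelet families listed above, 'haar', 'db1', 'sym2', 'coif1', 'bior1.3', 'rbio1.3', and 'dmey' are discrete wavelets in PyWavelets, while 'cgau1' and 'morl' are continuous-wavelet families that pywt.wavedec (which the new FeatureExtractor below relies on) does not accept. A minimal check, assuming only that pywt is installed:

import pywt

# Wavelet names taken from the config above.
configured = ['haar', 'cgau1', 'db1', 'sym2', 'coif1', 'bior1.3', 'rbio1.3', 'dmey', 'morl']
discrete = set(pywt.wavelist(kind='discrete'))  # names usable with pywt.wavedec

for name in configured:
    if name in discrete:
        print(f'{name}: ok for wavedec')
    else:
        print(f'{name}: continuous-only, wavedec will raise ValueError')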

main.py

@@ -4,8 +4,8 @@ import torch.nn as nn
import numpy as np
import random, math
from utils import visualize_prediction, plot_delta_distribution
-from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
-from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
+from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
+from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
import wandb
from pycallgraph2 import PyCallGraph
@@ -26,10 +26,10 @@ class SpikeRunner(Slate_Runner):
        data_url = slate.consume(data_config, 'url')
        cut_length = slate.consume(data_config, 'cut_length', None)
        download_and_extract_data(data_url)
-        all_data = load_all_wavs('data', cut_length)
+        self.all_data = load_all_wavs('data', cut_length)
        split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
-        self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
+        self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)
        print("Reconstructing thread topology")
        self.topology_matrix = compute_topology_metrics(self.train_data)
@@ -52,12 +52,13 @@ class SpikeRunner(Slate_Runner):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
+        self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
+        feature_size = self.feat.compute_output_size()
        if latent_projector_type == 'fc':
-            self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        elif latent_projector_type == 'rnn':
-            self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
-        elif latent_projector_type == 'fourier':
-            self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        else:
            raise Exception('No such Latent Projector')
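With this commit the projector is dimensioned by the FeatureExtractor's output size rather than the raw snippet length. A minimal sketch of the new wiring outside of Slate (transform list and layer sizes are illustrative; 'ReLU' is assumed to be a key registered in models.get_activation):

import torch
from models import FeatureExtractor, LatentFCProjector

transforms = [
    {'type': 'identity', 'length': 8},
    {'type': 'wavelet', 'wavelet_type': 'haar', 'length': None},
]
feat = FeatureExtractor(input_size=1953, transforms=transforms)
feature_size = feat.compute_output_size()  # 8 identity samples + all haar coefficients

projector = LatentFCProjector(feature_size=feature_size, latent_size=4,
                              layer_shapes=[32], activations=['ReLU'])

x = torch.randn(2, 3, 1953)    # (batch, threads, timesteps)
latent = projector(feat(x))    # -> (2, 3, 4)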
@@ -284,15 +285,8 @@ class SpikeRunner(Slate_Runner):
        all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
        if self.full_compression:
-            self.encoder.build_model(my_latent.cpu().numpy())
-            compressed_data = self.encoder.encode(my_latent.cpu().numpy())
-            decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
-            compression_ratio = len(my_latent) / len(compressed_data)
-            compression_ratios.append(compression_ratio)
-            if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
-                exact_matches += 1
-            total_sequences += 1
+            raw = self.all_data
+            comp = self.compress(raw)
        avg_loss = total_loss / len(self.test_data)
        tot_err = sum(errs) / len(errs)
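The removed lines measured the encoder round trip inline; the same check as a standalone helper, for reference (build_model/encode/decode signatures are taken from the removed code, so any of the bitstream encoders should fit):

import numpy as np

def roundtrip_stats(encoder, latent: np.ndarray, atol: float = 1e-5):
    # Fit the encoder's model, encode, decode, and compare against the input.
    encoder.build_model(latent)
    compressed = encoder.encode(latent)
    restored = encoder.decode(compressed, len(latent))
    ratio = len(latent) / len(compressed)
    exact = bool(np.allclose(latent, restored, atol=atol))
    return ratio, exact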
@@ -347,6 +341,17 @@ class SpikeRunner(Slate_Runner):
        torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
        print(f"New high score! Models saved at epoch {epoch+1}.")
+    def compress(self, raw):
+        threads = unfuckify_all(raw)
+        for thread in threads:
+            # 1. feature extraction
+            # 2. latent projection
+            # 3. middle out
+            # 4. predictor
+            # 5. compute deltas
+            # 6. encode
+            # 7. return
+            pass
if __name__ == '__main__':
    print('Initializing...')
    slate = Slate({'spikey': SpikeRunner})
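The compress stub above only names the seven steps. One way they might compose, as a hedged sketch rather than the committed implementation (self.feat, self.projector, self.predictor, self.encoder, and unfuckify_all all appear in this commit; the MiddleOut attribute name self.middle_out, its call signature, and the exact tensor shapes are assumptions):

import torch

def compress(self, raw):
    # Sketch only: MiddleOut/Predictor call signatures and shapes are assumed.
    bitstreams = []
    for thread in unfuckify_all(raw):                  # one recording thread at a time
        x = torch.tensor(thread, dtype=torch.float32, device=self.device).view(1, 1, -1)
        features = self.feat(x)                        # 1. feature extraction
        latent = self.projector(features)              # 2. latent projection
        region = self.middle_out(latent)               # 3. middle out (assumed attribute/API)
        prediction = self.predictor(region)            # 4. predict the next sample
        delta = x[..., -1] - prediction.squeeze()      # 5. residual between signal and prediction
        payload = delta.detach().cpu().numpy().ravel()
        self.encoder.build_model(payload)              # 6. entropy-code the deltas
        bitstreams.append(self.encoder.encode(payload))
    return bitstreams                                  # 7. one bitstream per thread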

models.py

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.fft as fft
+import pywt

def get_activation(name):
    activations = {
@@ -13,11 +14,77 @@ def get_activation(name):
    }
    return activations[name]()
+class FeatureExtractor(nn.Module):
+    def __init__(self, input_size, transforms):
+        super(FeatureExtractor, self).__init__()
+        self.input_size = input_size
+        self.transforms = self.build_transforms(transforms)
+
+    def build_transforms(self, config):
+        transforms = []
+        for item in config:
+            transform_type = item['type']
+            length = item.get('length', self.input_size)
+            if length in [None, -1]:
+                length = self.input_size
+            if transform_type == 'identity':
+                transforms.append(('identity', length))
+            elif transform_type == 'fourier':
+                transforms.append(('fourier', length))
+            elif transform_type == 'wavelet':
+                wavelet_type = item['wavelet_type']
+                transforms.append(('wavelet', wavelet_type, length))
+        return transforms
+
+    def forward(self, x):
+        batch_1, batch_2, timesteps = x.size()
+        x = x.view(batch_1 * batch_2, timesteps)  # Combine batch dimensions for processing
+        outputs = []
+        for transform in self.transforms:
+            if transform[0] == 'identity':
+                _, length = transform
+                outputs.append(x[:, -length:])
+            elif transform[0] == 'fourier':
+                _, length = transform
+                fourier_transform = fft.fft(x[:, -length:], dim=1)
+                fourier_real = fourier_transform.real
+                fourier_imag = fourier_transform.imag
+                outputs.append(fourier_real)
+                outputs.append(fourier_imag)
+            elif transform[0] == 'wavelet':
+                _, wavelet_type, length = transform
+                coeffs = pywt.wavedec(x[:, -length:].cpu().numpy(), wavelet_type)
+                wavelet_coeffs = [torch.tensor(coeff, dtype=torch.float32, device=x.device) for coeff in coeffs]
+                wavelet_coeffs = torch.cat(wavelet_coeffs, dim=1)
+                outputs.append(wavelet_coeffs)
+        concatenated_outputs = torch.cat(outputs, dim=1)
+        concatenated_outputs = concatenated_outputs.view(batch_1, batch_2, -1)  # Reshape back to original batch dimensions
+        return concatenated_outputs
+
+    def compute_output_size(self):
+        size = 0
+        for transform in self.transforms:
+            if transform[0] == 'identity':
+                _, length = transform
+                size += length
+            elif transform[0] == 'fourier':
+                _, length = transform
+                size += length * 2  # Fourier transform outputs both real and imaginary parts
+            elif transform[0] == 'wavelet':
+                _, wavelet_type, length = transform
+                # Find the true size of the wavelet coefficients
+                test_signal = torch.zeros(length)
+                coeffs = pywt.wavedec(test_signal.numpy(), wavelet_type)
+                wavelet_size = sum(len(c) for c in coeffs)
+                size += wavelet_size
+        return size
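A quick sanity check that forward() and compute_output_size() agree (a sketch; the transform list is illustrative):

import torch

feat = FeatureExtractor(input_size=64, transforms=[
    {'type': 'identity', 'length': 8},
    {'type': 'fourier', 'length': 16},   # contributes 16 real + 16 imaginary values
    {'type': 'wavelet', 'wavelet_type': 'db1', 'length': None},
])
x = torch.randn(4, 2, 64)                # (batch, threads, timesteps)
out = feat(x)
assert out.shape == (4, 2, feat.compute_output_size())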
class LatentFCProjector(nn.Module):
-    def __init__(self, input_size, latent_size, layer_shapes, activations):
+    def __init__(self, feature_size, latent_size, layer_shapes, activations):
        super(LatentFCProjector, self).__init__()
        layers = []
-        in_features = input_size
+        in_features = feature_size
        for i, out_features in enumerate(layer_shapes):
            layers.append(nn.Linear(in_features, out_features))
            if activations[i] != 'None':
@@ -31,9 +98,9 @@ class LatentFCProjector(nn.Module):
        return self.fc(x)

class LatentRNNProjector(nn.Module):
-    def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
+    def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size):
        super(LatentRNNProjector, self).__init__()
-        self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
+        self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, latent_size)
        self.latent_size = latent_size
@@ -43,43 +110,6 @@ class LatentRNNProjector(nn.Module):
        latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
        return latent
-class FourierTransformLayer(nn.Module):
-    def forward(self, x):
-        x_fft = fft.rfft(x, dim=-1)
-        return x_fft
-
-class LatentFourierProjector(nn.Module):
-    def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
-        super(LatentFourierProjector, self).__init__()
-        self.fourier_transform = FourierTransformLayer()
-        layers = []
-        if pass_raw_len is None:
-            pass_raw_len = input_size
-        else:
-            assert pass_raw_len <= input_size
-        in_features = pass_raw_len + (input_size // 2 + 1) * 2  # (input_size // 2 + 1) real + imaginary parts
-        for i, out_features in enumerate(layer_shapes):
-            layers.append(nn.Linear(in_features, out_features))
-            if activations[i] != 'None':
-                layers.append(get_activation(activations[i]))
-            in_features = out_features
-        layers.append(nn.Linear(in_features, latent_size))
-        self.fc = nn.Sequential(*layers)
-        self.latent_size = latent_size
-        self.pass_raw_len = pass_raw_len
-
-    def forward(self, x):
-        batch_1, batch_2, timesteps = x.size()
-        x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
-        x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
-        combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
-        latent = self.fc(combined_input)
-        latent = latent.view(batch_1, batch_2, self.latent_size)
-        return latent
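For reference, the removed projector's in_features formula relied on the output size of a real FFT: rfft of a length-n real signal yields n // 2 + 1 complex bins, hence (input_size // 2 + 1) * 2 once real and imaginary parts are concatenated. A one-line check:

import torch

n = 1953
assert torch.fft.rfft(torch.randn(n)).shape[-1] == n // 2 + 1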
class MiddleOut(nn.Module):
    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
        super(MiddleOut, self).__init__()