From eefacf884dc0487e367b26b76b49904d5e662491 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 27 May 2024 17:00:02 +0200 Subject: [PATCH] Feature Extraction --- config.yaml | 33 ++++++++++++++++ main.py | 39 ++++++++++-------- models.py | 112 +++++++++++++++++++++++++++++++++------------------- 3 files changed, 126 insertions(+), 58 deletions(-) diff --git a/config.yaml b/config.yaml index 2121aa5..18b4224 100644 --- a/config.yaml +++ b/config.yaml @@ -1,5 +1,38 @@ name: EXAMPLE +feature_extractor: + - type: 'identity' # Pass the last n samples of the input data directly. + length: 8 # Number of last samples to pass directly. Use full input size if set to null. + - type: 'fourier' # Apply Fourier transform to the input data. + length: null # Use full input size if set to null. Fourier transform outputs both real and imaginary parts, doubling the size. (Computationally expensive) + - type: 'wavelet' # Apply selected wavelet transform to the input data. + wavelet_type: 'haar' # Haar wavelet is simple and fast, but may not capture detailed features well. + length: null # Use full input size if set to null. + - type: 'wavelet' + wavelet_type: 'cgau1' # Complex Gaussian wavelets are used for complex-valued signal analysis and capturing phase information. NOTE: 'cgau1' is a continuous wavelet; pywt.wavedec accepts only discrete wavelets, so this entry needs a CWT code path (pywt.cwt) or must be removed. + length: null # Use full input size if set to null. + - type: 'wavelet' + wavelet_type: 'db1' # Daubechies wavelets provide a balance between time and frequency localization. + length: null # Use full input size if set to null. (Computationally expensive) + - type: 'wavelet' + wavelet_type: 'sym2' # Symlet wavelets are nearly symmetrical, offering improved phase characteristics over Daubechies. + length: null # Use full input size if set to null. (Computationally expensive) + - type: 'wavelet' + wavelet_type: 'coif1' # Coiflet wavelets have more vanishing moments, suitable for capturing polynomial trends. + length: null # Use full input size if set to null. 
(Computationally expensive) + - type: 'wavelet' + wavelet_type: 'bior1.3' # Biorthogonal wavelets provide perfect reconstruction and linear phase characteristics. + length: null # Use full input size if set to null. (Computationally expensive) + - type: 'wavelet' + wavelet_type: 'rbio1.3' # Reverse biorthogonal wavelets are the biorthogonal family with the decomposition and reconstruction filters swapped. + length: null # Use full input size if set to null. (Computationally expensive) + - type: 'wavelet' + wavelet_type: 'dmey' # Discrete Meyer wavelets offer good frequency localization, ideal for signals with oscillatory components. + length: null # Use full input size if set to null. (Computationally expensive) + - type: 'wavelet' + wavelet_type: 'morl' # Morlet wavelets are useful for time-frequency analysis due to their Gaussian-modulated sinusoid shape. NOTE: 'morl' is a continuous wavelet; pywt.wavedec accepts only discrete wavelets, so this entry needs a CWT code path (pywt.cwt) or must be removed. + length: null # Use full input size if set to null. (Computationally expensive) + latent_projector: type: 'fc' # Type of latent projector: 'fc', 'rnn', 'fourier' input_size: 1953 # Input size for the Latent Projector (length of snippets). 
(=0.1s) diff --git a/main.py b/main.py index 766ae98..3e3552b 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,8 @@ import torch.nn as nn import numpy as np import random, math from utils import visualize_prediction, plot_delta_distribution -from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics -from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor +from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all +from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder import wandb from pycallgraph2 import PyCallGraph @@ -26,10 +26,10 @@ class SpikeRunner(Slate_Runner): data_url = slate.consume(data_config, 'url') cut_length = slate.consume(data_config, 'cut_length', None) download_and_extract_data(data_url) - all_data = load_all_wavs('data', cut_length) + self.all_data = load_all_wavs('data', cut_length) split_ratio = slate.consume(data_config, 'split_ratio', 0.5) - self.train_data, self.test_data = split_data_by_time(all_data, split_ratio) + self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio) print("Reconstructing thread topology") self.topology_matrix = compute_topology_metrics(self.train_data) @@ -52,12 +52,13 @@ class SpikeRunner(Slate_Runner): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device + self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device) + feature_size = self.feat.compute_output_size() + if latent_projector_type == 'fc': - self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) + self.projector = 
LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) elif latent_projector_type == 'rnn': - self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) - elif latent_projector_type == 'fourier': - self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) + self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) else: raise Exception('No such Latent Projector') @@ -284,15 +285,8 @@ class SpikeRunner(Slate_Runner): all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist()) if self.full_compression: - self.encoder.build_model(my_latent.cpu().numpy()) - compressed_data = self.encoder.encode(my_latent.cpu().numpy()) - decompressed_data = self.encoder.decode(compressed_data, len(my_latent)) - compression_ratio = len(my_latent) / len(compressed_data) - compression_ratios.append(compression_ratio) - - if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5): - exact_matches += 1 - total_sequences += 1 + raw = self.all_data + comp = self.compress(raw) avg_loss = total_loss / len(self.test_data) tot_err = sum(errs) / len(errs) @@ -347,6 +341,17 @@ class SpikeRunner(Slate_Runner): torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt")) print(f"New high score! Models saved at epoch {epoch+1}.") + def compress(raw): + threads = unfuckify_all(raw) + for thread in threads: + # 1. featExtr + # 2. latentProj + # 3. middleOut + # 4. predictor + # 5. calc delta + # 6. encode + # 7. 
return + if __name__ == '__main__': print('Initializing...') slate = Slate({'spikey': SpikeRunner}) diff --git a/models.py b/models.py index 8d999f5..8eb7f9b 100644 --- a/models.py +++ b/models.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.fft as fft +import pywt def get_activation(name): activations = { @@ -13,11 +14,77 @@ def get_activation(name): } return activations[name]() +class FeatureExtractor(nn.Module): + def __init__(self, input_size, transforms): + super(FeatureExtractor, self).__init__() + self.input_size = input_size + self.transforms = self.build_transforms(transforms) + + def build_transforms(self, config): + transforms = [] + for item in config: + transform_type = item['type'] + length = item.get('length', self.input_size) + if length in [None, -1]: + length = self.input_size + + if transform_type == 'identity': + transforms.append(('identity', length)) + elif transform_type == 'fourier': + transforms.append(('fourier', length)) + elif transform_type == 'wavelet': + wavelet_type = item['wavelet_type'] + transforms.append(('wavelet', wavelet_type, length)) + return transforms + + def forward(self, x): + batch_1, batch_2, timesteps = x.size() + x = x.view(batch_1 * batch_2, timesteps) # Combine batch dimensions for processing + outputs = [] + for transform in self.transforms: + if transform[0] == 'identity': + _, length = transform + outputs.append(x[:, -length:]) + elif transform[0] == 'fourier': + _, length = transform + fourier_transform = fft.fft(x[:, -length:], dim=1) + fourier_real = fourier_transform.real + fourier_imag = fourier_transform.imag + outputs.append(fourier_real) + outputs.append(fourier_imag) + elif transform[0] == 'wavelet': + _, wavelet_type, length = transform + coeffs = pywt.wavedec(x[:, -length:].cpu().numpy(), wavelet_type) + wavelet_coeffs = [torch.tensor(coeff, dtype=torch.float32, device=x.device) for coeff in coeffs] + wavelet_coeffs = torch.cat(wavelet_coeffs, dim=1) + outputs.append(wavelet_coeffs) 
+ concatenated_outputs = torch.cat(outputs, dim=1) + concatenated_outputs = concatenated_outputs.view(batch_1, batch_2, -1) # Reshape back to original batch dimensions + return concatenated_outputs + + def compute_output_size(self): + size = 0 + for transform in self.transforms: + if transform[0] == 'identity': + _, length = transform + size += length + elif transform[0] == 'fourier': + _, length = transform + size += length * 2 # Fourier transform outputs both real and imaginary parts + elif transform[0] == 'wavelet': + _, wavelet_type, length = transform + # Find the true size of the wavelet coefficients + test_signal = torch.zeros(length) + coeffs = pywt.wavedec(test_signal.numpy(), wavelet_type) + wavelet_size = sum(len(c) for c in coeffs) + size += wavelet_size + return size + class LatentFCProjector(nn.Module): - def __init__(self, input_size, latent_size, layer_shapes, activations): + def __init__(self, feature_size, latent_size, layer_shapes, activations): super(LatentFCProjector, self).__init__() layers = [] - in_features = input_size + in_features = feature_size for i, out_features in enumerate(layer_shapes): layers.append(nn.Linear(in_features, out_features)) if activations[i] != 'None': @@ -31,9 +98,9 @@ class LatentFCProjector(nn.Module): return self.fc(x) class LatentRNNProjector(nn.Module): - def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size): + def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size): super(LatentRNNProjector, self).__init__() - self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True) + self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True) self.fc = nn.Linear(rnn_hidden_size, latent_size) self.latent_size = latent_size @@ -43,43 +110,6 @@ class LatentRNNProjector(nn.Module): latent = self.fc(out).view(batch_1, batch_2, self.latent_size) return latent -class FourierTransformLayer(nn.Module): - def forward(self, x): - x_fft = 
fft.rfft(x, dim=-1) - return x_fft - -class LatentFourierProjector(nn.Module): - def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None): - super(LatentFourierProjector, self).__init__() - self.fourier_transform = FourierTransformLayer() - layers = [] - - if pass_raw_len is None: - pass_raw_len = input_size - else: - assert pass_raw_len <= input_size - - in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts - for i, out_features in enumerate(layer_shapes): - layers.append(nn.Linear(in_features, out_features)) - if activations[i] != 'None': - layers.append(get_activation(activations[i])) - in_features = out_features - - layers.append(nn.Linear(in_features, latent_size)) - self.fc = nn.Sequential(*layers) - self.latent_size = latent_size - self.pass_raw_len = pass_raw_len - - def forward(self, x): - batch_1, batch_2, timesteps = x.size() - x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps)) - x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1) - combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1) - latent = self.fc(combined_input) - latent = latent.view(batch_1, batch_2, self.latent_size) - return latent - class MiddleOut(nn.Module): def __init__(self, latent_size, region_latent_size, num_peers, residual=False): super(MiddleOut, self).__init__()