Feature Extraction
This commit is contained in:
parent d9d0e2b8c8
commit eefacf884d
config.yaml (33 lines changed)
@@ -1,5 +1,38 @@
 name: EXAMPLE
 
+feature_extractor:
+  - type: 'identity'  # Pass the last n samples of the input data through directly.
+    length: 8  # Number of last samples to pass through. Uses the full input size if set to null.
+  - type: 'fourier'  # Apply a Fourier transform to the input data.
+    length: null  # Uses the full input size if set to null. The Fourier transform outputs both real and imaginary parts, doubling the size. (Computationally expensive)
+  - type: 'wavelet'  # Apply the selected wavelet transform to the input data.
+    wavelet_type: 'haar'  # The Haar wavelet is simple and fast, but may not capture detailed features well.
+    length: null  # Uses the full input size if set to null.
+  - type: 'wavelet'
+    wavelet_type: 'cgau1'  # Complex Gaussian wavelets are used for complex-valued signal analysis and capturing phase information.
+    length: null  # Uses the full input size if set to null.
+  - type: 'wavelet'
+    wavelet_type: 'db1'  # Daubechies wavelets provide a balance between time and frequency localization.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'sym2'  # Symlet wavelets are nearly symmetrical, offering improved phase characteristics over Daubechies.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'coif1'  # Coiflet wavelets have more vanishing moments, suitable for capturing polynomial trends.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'bior1.3'  # Biorthogonal wavelets provide perfect reconstruction and linear phase characteristics.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'rbio1.3'  # Reverse biorthogonal wavelets are similar to biorthogonal ones but optimized for different applications.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'dmey'  # Discrete Meyer wavelets offer good frequency localization, ideal for signals with oscillatory components.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+  - type: 'wavelet'
+    wavelet_type: 'morl'  # Morlet wavelets are useful for time-frequency analysis due to their Gaussian-modulated sinusoid shape.
+    length: null  # Uses the full input size if set to null. (Computationally expensive)
+
 latent_projector:
   type: 'fc'  # Type of latent projector: 'fc', 'rnn', 'fourier'
   input_size: 1953  # Input size for the Latent Projector (length of snippets). (= 0.1 s)
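For a quick sanity check on a config like this, the total feature width handed to the latent projector can be computed directly. A minimal sketch, assuming PyWavelets is installed and the 1953-sample window from latent_projector.input_size; note that 'cgau1' and 'morl' are continuous wavelets, which pywt.wavedec rejects, so they are left out here (see the note in models.py below):

import pywt

INPUT_SIZE = 1953  # latent_projector.input_size above

def wavelet_width(wavelet_type, length):
    # Total coefficient count wavedec yields for a signal of this length.
    return sum(len(c) for c in pywt.wavedec([0.0] * length, wavelet_type))

width = 8                # identity: last 8 samples
width += INPUT_SIZE * 2  # fourier: real + imaginary parts
for w in ['haar', 'db1', 'sym2', 'coif1', 'bior1.3', 'rbio1.3', 'dmey']:
    width += wavelet_width(w, INPUT_SIZE)
print(width)             # feature vector length per snippet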
main.py (39 lines changed)
@@ -4,8 +4,8 @@ import torch.nn as nn
 import numpy as np
 import random, math
 from utils import visualize_prediction, plot_delta_distribution
-from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
-from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector, MiddleOut, Predictor
+from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
+from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
 from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
 import wandb
 from pycallgraph2 import PyCallGraph
@@ -26,10 +26,10 @@ class SpikeRunner(Slate_Runner):
         data_url = slate.consume(data_config, 'url')
         cut_length = slate.consume(data_config, 'cut_length', None)
         download_and_extract_data(data_url)
-        all_data = load_all_wavs('data', cut_length)
+        self.all_data = load_all_wavs('data', cut_length)
 
         split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
-        self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
+        self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)
 
         print("Reconstructing thread topology")
         self.topology_matrix = compute_topology_metrics(self.train_data)
@@ -52,12 +52,13 @@ class SpikeRunner(Slate_Runner):
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.device = device
 
+        self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
+        feature_size = self.feat.compute_output_size()
+
         if latent_projector_type == 'fc':
-            self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         elif latent_projector_type == 'rnn':
-            self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
-        elif latent_projector_type == 'fourier':
-            self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         else:
             raise Exception('No such Latent Projector')
 
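The effect of this hunk is that projectors are now sized from the extractor's output rather than the raw snippet length. A minimal shape check with purely illustrative sizes; it assumes the FC projector ends with a linear layer to latent_size (as its removed Fourier sibling did), and uses 'None' activations to sidestep get_activation's name mapping:

import torch

feat = FeatureExtractor(input_size=1953, transforms=[{'type': 'identity', 'length': 8}])
feature_size = feat.compute_output_size()  # -> 8
projector = LatentFCProjector(feature_size=feature_size, latent_size=4,
                              layer_shapes=[16], activations=['None'])
x = torch.randn(2, 3, 1953)   # (batch_1, batch_2, timesteps)
latent = projector(feat(x))   # -> (2, 3, 4)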
@@ -284,15 +285,8 @@ class SpikeRunner(Slate_Runner):
                 all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
 
                 if self.full_compression:
-                    self.encoder.build_model(my_latent.cpu().numpy())
-                    compressed_data = self.encoder.encode(my_latent.cpu().numpy())
-                    decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
-                    compression_ratio = len(my_latent) / len(compressed_data)
-                    compression_ratios.append(compression_ratio)
-
-                    if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
-                        exact_matches += 1
-                    total_sequences += 1
+                    raw = self.all_data
+                    comp = self.compress(raw)
 
         avg_loss = total_loss / len(self.test_data)
         tot_err = sum(errs) / len(errs)
@@ -347,6 +341,17 @@ class SpikeRunner(Slate_Runner):
             torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
             print(f"New high score! Models saved at epoch {epoch+1}.")
 
+    def compress(self, raw):
+        threads = unfuckify_all(raw)
+        for thread in threads:
+            # 1. featExtr
+            # 2. latentProj
+            # 3. middleOut
+            # 4. predictor
+            # 5. calc delta
+            # 6. encode
+            # 7. return
+            pass  # TODO: implement steps 1-7
+
 if __name__ == '__main__':
     print('Initializing...')
     slate = Slate({'spikey': SpikeRunner})
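compress is only stubbed in here (it is called as self.compress(raw) above, hence the added self parameter). One way the numbered steps might be fleshed out, purely as a sketch: self.feat, self.projector, self.predictor and self.encoder follow attribute names used elsewhere in this file, while self.middle_out, the call signatures, and the windowing of each thread are assumptions:

def compress(self, raw):
    threads = unfuckify_all(raw)
    blobs = []
    for thread in threads:                       # shaping/windowing of `thread` elided
        feats = self.feat(thread)                # 1. featExtr
        latent = self.projector(feats)           # 2. latentProj
        region = self.middle_out(latent)         # 3. middleOut (signature assumed)
        prediction = self.predictor(region)      # 4. predictor
        delta = thread - prediction              # 5. calc delta (alignment elided)
        blobs.append(self.encoder.encode(delta)) # 6. encode
    return blobs                                 # 7. return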
models.py (112 lines changed)
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.fft as fft
+import pywt
 
 def get_activation(name):
     activations = {
@@ -13,11 +14,77 @@ def get_activation(name):
     }
     return activations[name]()
 
+class FeatureExtractor(nn.Module):
+    def __init__(self, input_size, transforms):
+        super(FeatureExtractor, self).__init__()
+        self.input_size = input_size
+        self.transforms = self.build_transforms(transforms)
+
+    def build_transforms(self, config):
+        transforms = []
+        for item in config:
+            transform_type = item['type']
+            length = item.get('length', self.input_size)
+            if length in [None, -1]:
+                length = self.input_size
+
+            if transform_type == 'identity':
+                transforms.append(('identity', length))
+            elif transform_type == 'fourier':
+                transforms.append(('fourier', length))
+            elif transform_type == 'wavelet':
+                wavelet_type = item['wavelet_type']
+                transforms.append(('wavelet', wavelet_type, length))
+        return transforms
+
+    def forward(self, x):
+        batch_1, batch_2, timesteps = x.size()
+        x = x.view(batch_1 * batch_2, timesteps)  # Combine batch dimensions for processing
+        outputs = []
+        for transform in self.transforms:
+            if transform[0] == 'identity':
+                _, length = transform
+                outputs.append(x[:, -length:])
+            elif transform[0] == 'fourier':
+                _, length = transform
+                fourier_transform = fft.fft(x[:, -length:], dim=1)
+                fourier_real = fourier_transform.real
+                fourier_imag = fourier_transform.imag
+                outputs.append(fourier_real)
+                outputs.append(fourier_imag)
+            elif transform[0] == 'wavelet':
+                _, wavelet_type, length = transform
+                # Note: pywt.wavedec only accepts discrete wavelets; continuous families
+                # such as 'cgau1' or 'morl' (listed in config.yaml) raise an error here.
+                coeffs = pywt.wavedec(x[:, -length:].cpu().numpy(), wavelet_type)
+                wavelet_coeffs = [torch.tensor(coeff, dtype=torch.float32, device=x.device) for coeff in coeffs]
+                wavelet_coeffs = torch.cat(wavelet_coeffs, dim=1)
+                outputs.append(wavelet_coeffs)
+        concatenated_outputs = torch.cat(outputs, dim=1)
+        concatenated_outputs = concatenated_outputs.view(batch_1, batch_2, -1)  # Reshape back to original batch dimensions
+        return concatenated_outputs
+
+    def compute_output_size(self):
+        size = 0
+        for transform in self.transforms:
+            if transform[0] == 'identity':
+                _, length = transform
+                size += length
+            elif transform[0] == 'fourier':
+                _, length = transform
+                size += length * 2  # Fourier transform outputs both real and imaginary parts
+            elif transform[0] == 'wavelet':
+                _, wavelet_type, length = transform
+                # Find the true size of the wavelet coefficients
+                test_signal = torch.zeros(length)
+                coeffs = pywt.wavedec(test_signal.numpy(), wavelet_type)
+                wavelet_size = sum(len(c) for c in coeffs)
+                size += wavelet_size
+        return size
+
 class LatentFCProjector(nn.Module):
-    def __init__(self, input_size, latent_size, layer_shapes, activations):
+    def __init__(self, feature_size, latent_size, layer_shapes, activations):
         super(LatentFCProjector, self).__init__()
         layers = []
-        in_features = input_size
+        in_features = feature_size
         for i, out_features in enumerate(layer_shapes):
             layers.append(nn.Linear(in_features, out_features))
             if activations[i] != 'None':
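A quick consistency check for the new class: forward's output width should match compute_output_size for any transform list, since both walk the same transforms. A minimal sketch, assuming PyWavelets is installed:

import torch

feat = FeatureExtractor(input_size=32, transforms=[
    {'type': 'identity', 'length': 8},
    {'type': 'fourier', 'length': None},  # None -> full input size
    {'type': 'wavelet', 'wavelet_type': 'haar', 'length': None},
])
x = torch.randn(2, 3, 32)  # (batch_1, batch_2, timesteps)
out = feat(x)
assert out.shape == (2, 3, feat.compute_output_size())  # 8 + 64 + 32 = 104 features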
@@ -31,9 +98,9 @@ class LatentFCProjector(nn.Module):
         return self.fc(x)
 
 class LatentRNNProjector(nn.Module):
-    def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
+    def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size):
         super(LatentRNNProjector, self).__init__()
-        self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
+        self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
         self.fc = nn.Linear(rnn_hidden_size, latent_size)
         self.latent_size = latent_size
 
@@ -43,43 +110,6 @@ class LatentRNNProjector(nn.Module):
         latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
         return latent
 
-class FourierTransformLayer(nn.Module):
-    def forward(self, x):
-        x_fft = fft.rfft(x, dim=-1)
-        return x_fft
-
-class LatentFourierProjector(nn.Module):
-    def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
-        super(LatentFourierProjector, self).__init__()
-        self.fourier_transform = FourierTransformLayer()
-        layers = []
-
-        if pass_raw_len is None:
-            pass_raw_len = input_size
-        else:
-            assert pass_raw_len <= input_size
-
-        in_features = pass_raw_len + (input_size // 2 + 1) * 2  # (input_size // 2 + 1) real + imaginary parts
-        for i, out_features in enumerate(layer_shapes):
-            layers.append(nn.Linear(in_features, out_features))
-            if activations[i] != 'None':
-                layers.append(get_activation(activations[i]))
-            in_features = out_features
-
-        layers.append(nn.Linear(in_features, latent_size))
-        self.fc = nn.Sequential(*layers)
-        self.latent_size = latent_size
-        self.pass_raw_len = pass_raw_len
-
-    def forward(self, x):
-        batch_1, batch_2, timesteps = x.size()
-        x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps))
-        x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1)
-        combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1)
-        latent = self.fc(combined_input)
-        latent = latent.view(batch_1, batch_2, self.latent_size)
-        return latent
-
 class MiddleOut(nn.Module):
     def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
         super(MiddleOut, self).__init__()
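The removed LatentFourierProjector is effectively subsumed by FeatureExtractor: an 'identity' entry plays the role of pass_raw_len and a 'fourier' entry supplies the real/imaginary spectrum, with LatentFCProjector handling the projection. One caveat: the old layer used fft.rfft (input_size // 2 + 1 bins), while FeatureExtractor uses the full fft.fft, so the resulting feature width differs. A hypothetical replacement, expressed as constructor arguments (sizes are placeholders):

feat = FeatureExtractor(input_size=1953, transforms=[
    {'type': 'identity', 'length': 8},    # plays the role of pass_raw_len
    {'type': 'fourier', 'length': None},  # real + imaginary parts, as before
])
projector = LatentFCProjector(feature_size=feat.compute_output_size(),
                              latent_size=4, layer_shapes=[16], activations=['None'])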