Feature Extraction
This commit is contained in:
parent
d9d0e2b8c8
commit
eefacf884d
33
config.yaml
33
config.yaml
@ -1,5 +1,38 @@
|
|||||||
name: EXAMPLE
|
name: EXAMPLE
|
||||||
|
|
||||||
|
feature_extractor:
|
||||||
|
- type: 'identity' # Pass the last n samples of the input data directly.
|
||||||
|
length: 8 # Number of last samples to pass directly. Use full input size if set to null.
|
||||||
|
- type: 'fourier' # Apply Fourier transform to the input data.
|
||||||
|
length: null # Use full input size if set to null. Fourier transform outputs both real and imaginary parts, doubling the size. (Computationally expensive)
|
||||||
|
- type: 'wavelet' # Apply selected wavelet transform to the input data.
|
||||||
|
wavelet_type: 'haar' # Haar wavelet is simple and fast, but may not capture detailed features well.
|
||||||
|
length: null # Use full input size if set to null.
|
||||||
|
- type: 'wavelet'
|
||||||
|
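wavelet_type: 'cgau1' # Complex Gaussian wavelets are used for complex-valued signal analysis and capturing phase information. NOTE: this is a continuous wavelet; pywt.wavedec only accepts discrete wavelets, so this entry is expected to raise a ValueError.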
|
||||||
|
length: null # Use full input size if set to null.
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'db1' # Daubechies wavelets provide a balance between time and frequency localization.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'sym2' # Symlet wavelets are nearly symmetrical, offering improved phase characteristics over Daubechies.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'coif1' # Coiflet wavelets have more vanishing moments, suitable for capturing polynomial trends.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'bior1.3' # Biorthogonal wavelets provide perfect reconstruction and linear phase characteristics.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'rbio1.3' # Reverse Biorthogonal wavelets are similar to Biorthogonal but optimized for different applications.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
wavelet_type: 'dmey' # Discrete Meyer wavelets offer good frequency localization, ideal for signals with oscillatory components.
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
- type: 'wavelet'
|
||||||
|
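wavelet_type: 'morl' # Morlet wavelets are useful for time-frequency analysis due to their Gaussian-modulated sinusoid shape. NOTE: this is a continuous wavelet; pywt.wavedec only accepts discrete wavelets, so this entry is expected to raise a ValueError.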
|
||||||
|
length: null # Use full input size if set to null. (Computationally expensive)
|
||||||
|
|
||||||
latent_projector:
|
latent_projector:
|
||||||
type: 'fc' # Type of latent projector: 'fc', 'rnn', 'fourier'
|
type: 'fc' # Type of latent projector: 'fc' or 'rnn'. (Fourier features are configured under feature_extractor instead.)
|
||||||
input_size: 1953 # Input size for the Latent Projector (length of snippets). (=0.1s)
|
input_size: 1953 # Input size for the Latent Projector (length of snippets). (=0.1s)
|
||||||
|
39
main.py
39
main.py
@ -4,8 +4,8 @@ import torch.nn as nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import random, math
|
import random, math
|
||||||
from utils import visualize_prediction, plot_delta_distribution
|
from utils import visualize_prediction, plot_delta_distribution
|
||||||
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
|
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
|
||||||
from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor
|
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
|
||||||
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
|
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
|
||||||
import wandb
|
import wandb
|
||||||
from pycallgraph2 import PyCallGraph
|
from pycallgraph2 import PyCallGraph
|
||||||
@ -26,10 +26,10 @@ class SpikeRunner(Slate_Runner):
|
|||||||
data_url = slate.consume(data_config, 'url')
|
data_url = slate.consume(data_config, 'url')
|
||||||
cut_length = slate.consume(data_config, 'cut_length', None)
|
cut_length = slate.consume(data_config, 'cut_length', None)
|
||||||
download_and_extract_data(data_url)
|
download_and_extract_data(data_url)
|
||||||
all_data = load_all_wavs('data', cut_length)
|
self.all_data = load_all_wavs('data', cut_length)
|
||||||
|
|
||||||
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
|
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
|
||||||
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
|
self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)
|
||||||
|
|
||||||
print("Reconstructing thread topology")
|
print("Reconstructing thread topology")
|
||||||
self.topology_matrix = compute_topology_metrics(self.train_data)
|
self.topology_matrix = compute_topology_metrics(self.train_data)
|
||||||
@ -52,12 +52,13 @@ class SpikeRunner(Slate_Runner):
|
|||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
|
self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
|
||||||
|
feature_size = self.feat.compute_output_size()
|
||||||
|
|
||||||
if latent_projector_type == 'fc':
|
if latent_projector_type == 'fc':
|
||||||
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
||||||
elif latent_projector_type == 'rnn':
|
elif latent_projector_type == 'rnn':
|
||||||
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
||||||
elif latent_projector_type == 'fourier':
|
|
||||||
self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
|
||||||
else:
|
else:
|
||||||
raise Exception('No such Latent Projector')
|
raise Exception('No such Latent Projector')
|
||||||
|
|
||||||
@ -284,15 +285,8 @@ class SpikeRunner(Slate_Runner):
|
|||||||
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
|
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
|
||||||
|
|
||||||
if self.full_compression:
|
if self.full_compression:
|
||||||
self.encoder.build_model(my_latent.cpu().numpy())
|
raw = self.all_data
|
||||||
compressed_data = self.encoder.encode(my_latent.cpu().numpy())
|
comp = self.compress(raw)
|
||||||
decompressed_data = self.encoder.decode(compressed_data, len(my_latent))
|
|
||||||
compression_ratio = len(my_latent) / len(compressed_data)
|
|
||||||
compression_ratios.append(compression_ratio)
|
|
||||||
|
|
||||||
if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5):
|
|
||||||
exact_matches += 1
|
|
||||||
total_sequences += 1
|
|
||||||
|
|
||||||
avg_loss = total_loss / len(self.test_data)
|
avg_loss = total_loss / len(self.test_data)
|
||||||
tot_err = sum(errs) / len(errs)
|
tot_err = sum(errs) / len(errs)
|
||||||
@ -347,6 +341,17 @@ class SpikeRunner(Slate_Runner):
|
|||||||
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
|
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
|
||||||
print(f"New high score! Models saved at epoch {epoch+1}.")
|
print(f"New high score! Models saved at epoch {epoch+1}.")
|
||||||
|
|
||||||
|
def compress(self, raw):
    """Compress raw recordings end to end (not implemented yet).

    This is a method: it is invoked as ``self.compress(raw)``, so it must
    accept ``self``. The original stub omitted it and had a loop body made
    only of comments, which does not parse.

    Args:
        raw: The raw loaded data; it is passed straight to ``unfuckify_all``.

    Raises:
        NotImplementedError: Always, until the pipeline steps listed below
            are filled in. Failing loudly avoids handing ``None`` back to
            the caller as if it were compressed data.
    """
    threads = unfuckify_all(raw)
    for thread in threads:
        # Planned per-thread pipeline:
        # 1. featExtr
        # 2. latentProj
        # 3. middleOut
        # 4. predictor
        # 5. calc delta
        # 6. encode
        # 7. return
        pass
    raise NotImplementedError("compress() is a stub: the compression pipeline is not implemented yet")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print('Initializing...')
|
print('Initializing...')
|
||||||
slate = Slate({'spikey': SpikeRunner})
|
slate = Slate({'spikey': SpikeRunner})
|
||||||
|
112
models.py
112
models.py
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.fft as fft
|
import torch.fft as fft
|
||||||
|
import pywt
|
||||||
|
|
||||||
def get_activation(name):
|
def get_activation(name):
|
||||||
activations = {
|
activations = {
|
||||||
@ -13,11 +14,77 @@ def get_activation(name):
|
|||||||
}
|
}
|
||||||
return activations[name]()
|
return activations[name]()
|
||||||
|
|
||||||
|
class FeatureExtractor(nn.Module):
    """Concatenate a configurable list of fixed (non-learned) signal transforms.

    Each configured transform is applied to the last ``length`` samples of
    every input row, and all results are concatenated along the feature axis.

    Args:
        input_size: Number of time steps per input row; used as the default
            window length and as the upper bound for explicit lengths.
        transforms: Iterable of dicts. Each dict has a ``'type'`` key
            (``'identity'``, ``'fourier'`` or ``'wavelet'``), an optional
            ``'length'`` (``None`` or ``-1`` means ``input_size``) and, for
            wavelets, a ``'wavelet_type'`` name understood by ``pywt``.

    Raises:
        ValueError: If a transform type is unknown or a length is outside
            ``1..input_size``.
        KeyError: If ``'type'`` (or ``'wavelet_type'`` for wavelets) is missing.
    """

    def __init__(self, input_size, transforms):
        super(FeatureExtractor, self).__init__()
        self.input_size = input_size
        self.transforms = self.build_transforms(transforms)

    def build_transforms(self, config):
        """Validate ``config`` and turn it into a list of transform tuples.

        Returns:
            A list of ``(type, length)`` tuples for identity/fourier entries
            and ``('wavelet', wavelet_type, length)`` tuples for wavelets.
        """
        transforms = []
        for item in config:
            transform_type = item['type']
            length = item.get('length', self.input_size)
            if length in [None, -1]:
                length = self.input_size

            # A length of 0 would select the WHOLE row (x[:, -0:]) and a length
            # above input_size is silently clamped by slicing; both would make
            # forward() disagree with compute_output_size(), so reject them.
            if not 1 <= length <= self.input_size:
                raise ValueError(
                    f"Transform length must be in 1..{self.input_size}, got {length!r}"
                )

            if transform_type in ('identity', 'fourier'):
                transforms.append((transform_type, length))
            elif transform_type == 'wavelet':
                transforms.append(('wavelet', item['wavelet_type'], length))
            else:
                # Previously unknown types were dropped without any warning.
                raise ValueError(f"Unknown feature transform type: {transform_type!r}")
        return transforms

    def forward(self, x):
        """Apply every configured transform and concatenate the results.

        Args:
            x: Tensor with three dimensions; the first two are treated as
                batch dimensions and the last one as time.

        Returns:
            Tensor with the same two leading dimensions and the concatenated
            features as the last dimension.
        """
        batch_1, batch_2, timesteps = x.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(batch_1 * batch_2, timesteps)
        outputs = []
        for transform in self.transforms:
            kind, length = transform[0], transform[-1]
            window = flat[:, -length:]
            if kind == 'identity':
                outputs.append(window)
            elif kind == 'fourier':
                spectrum = fft.fft(window, dim=1)
                # Real and imaginary parts are emitted as separate features.
                outputs.append(spectrum.real)
                outputs.append(spectrum.imag)
            elif kind == 'wavelet':
                wavelet_type = transform[1]
                # pywt works on NumPy arrays; detach so tensors that require
                # grad can be converted (no gradient flows through this path).
                coeffs = pywt.wavedec(window.detach().cpu().numpy(), wavelet_type)
                outputs.extend(
                    torch.tensor(coeff, dtype=torch.float32, device=x.device)
                    for coeff in coeffs
                )
        features = torch.cat(outputs, dim=1)
        return features.reshape(batch_1, batch_2, -1)

    def compute_output_size(self):
        """Return the size of the last dimension produced by ``forward``."""
        size = 0
        for transform in self.transforms:
            kind, length = transform[0], transform[-1]
            if kind == 'identity':
                size += length
            elif kind == 'fourier':
                size += length * 2  # real + imaginary parts
            elif kind == 'wavelet':
                # The coefficient count depends on the wavelet's filter length,
                # so measure it by decomposing a dummy signal.
                probe = torch.zeros(length).numpy()
                coeffs = pywt.wavedec(probe, transform[1])
                size += sum(len(c) for c in coeffs)
        return size
|
||||||
|
|
||||||
class LatentFCProjector(nn.Module):
|
class LatentFCProjector(nn.Module):
|
||||||
def __init__(self, input_size, latent_size, layer_shapes, activations):
|
def __init__(self, feature_size, latent_size, layer_shapes, activations):
|
||||||
super(LatentFCProjector, self).__init__()
|
super(LatentFCProjector, self).__init__()
|
||||||
layers = []
|
layers = []
|
||||||
in_features = input_size
|
in_features = feature_size
|
||||||
for i, out_features in enumerate(layer_shapes):
|
for i, out_features in enumerate(layer_shapes):
|
||||||
layers.append(nn.Linear(in_features, out_features))
|
layers.append(nn.Linear(in_features, out_features))
|
||||||
if activations[i] != 'None':
|
if activations[i] != 'None':
|
||||||
@ -31,9 +98,9 @@ class LatentFCProjector(nn.Module):
|
|||||||
return self.fc(x)
|
return self.fc(x)
|
||||||
|
|
||||||
class LatentRNNProjector(nn.Module):
|
class LatentRNNProjector(nn.Module):
|
||||||
def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
|
def __init__(self, feature_size, rnn_hidden_size, rnn_num_layers, latent_size):
|
||||||
super(LatentRNNProjector, self).__init__()
|
super(LatentRNNProjector, self).__init__()
|
||||||
self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
|
self.rnn = nn.LSTM(feature_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
|
||||||
self.fc = nn.Linear(rnn_hidden_size, latent_size)
|
self.fc = nn.Linear(rnn_hidden_size, latent_size)
|
||||||
self.latent_size = latent_size
|
self.latent_size = latent_size
|
||||||
|
|
||||||
@ -43,43 +110,6 @@ class LatentRNNProjector(nn.Module):
|
|||||||
latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
|
latent = self.fc(out).view(batch_1, batch_2, self.latent_size)
|
||||||
return latent
|
return latent
|
||||||
|
|
||||||
class FourierTransformLayer(nn.Module):
    """Parameter-free layer that applies a real-input FFT along the last axis."""

    def forward(self, x):
        """Return the complex ``rfft`` of ``x`` computed over its last dimension."""
        return torch.fft.rfft(x, dim=-1)
|
|
||||||
|
|
||||||
class LatentFourierProjector(nn.Module):
    """Project raw samples plus their real-FFT spectrum to a latent vector.

    The fully connected stack receives the last ``pass_raw_len`` raw samples
    concatenated with the real and imaginary parts of the ``rfft`` of the
    whole input row.

    Args:
        input_size: Number of time steps per input row.
        latent_size: Size of the produced latent vector.
        layer_shapes: Output sizes of the hidden linear layers, in order.
        activations: Activation name per hidden layer; the string ``'None'``
            means no activation after that layer.
        pass_raw_len: How many trailing raw samples to feed alongside the
            spectrum. ``None`` means all ``input_size`` samples; ``0`` means
            spectrum only.

    Raises:
        ValueError: If ``pass_raw_len`` is not in ``0..input_size``.
    """

    def __init__(self, input_size, latent_size, layer_shapes, activations, pass_raw_len=None):
        super(LatentFourierProjector, self).__init__()
        self.fourier_transform = FourierTransformLayer()

        if pass_raw_len is None:
            pass_raw_len = input_size
        elif not 0 <= pass_raw_len <= input_size:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"pass_raw_len must be in 0..{input_size}, got {pass_raw_len!r}"
            )

        # rfft yields input_size // 2 + 1 complex bins -> real + imaginary parts.
        in_features = pass_raw_len + (input_size // 2 + 1) * 2
        layers = []
        for i, out_features in enumerate(layer_shapes):
            layers.append(nn.Linear(in_features, out_features))
            if activations[i] != 'None':
                layers.append(get_activation(activations[i]))
            in_features = out_features
        layers.append(nn.Linear(in_features, latent_size))

        self.fc = nn.Sequential(*layers)
        self.latent_size = latent_size
        self.pass_raw_len = pass_raw_len

    def forward(self, x):
        """Compute latents for a tensor whose last dimension is time.

        Args:
            x: Tensor with three dimensions; the first two are treated as
                batch dimensions.

        Returns:
            Tensor with the same two leading dimensions and ``latent_size``
            as the last dimension.
        """
        batch_1, batch_2, timesteps = x.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(batch_1 * batch_2, timesteps)
        x_fft = self.fourier_transform(flat)
        spectrum = torch.cat((x_fft.real, x_fft.imag), dim=-1)
        # Slice from an explicit start index: `flat[:, -0:]` would return the
        # whole row when pass_raw_len == 0 instead of an empty slice.
        start = max(timesteps - self.pass_raw_len, 0)
        combined_input = torch.cat([flat[:, start:], spectrum], dim=-1)
        latent = self.fc(combined_input)
        return latent.view(batch_1, batch_2, self.latent_size)
|
|
||||||
|
|
||||||
class MiddleOut(nn.Module):
|
class MiddleOut(nn.Module):
|
||||||
def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
|
def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
|
||||||
super(MiddleOut, self).__init__()
|
super(MiddleOut, self).__init__()
|
||||||
|
Loading…
Reference in New Issue
Block a user