Changed everything

This commit is contained in:
parent 73b306dc50
commit 29946baff0
.gitignore (vendored) | 2

```diff
@@ -8,3 +8,5 @@ job_hist.log
 models
 Xvfb.log
 profiler
+.ipynb_checkpoints/
+
```
README.md | 32

````diff
@@ -6,6 +6,24 @@ This repository contains a solution for the [Neuralink Compression Challenge](ht
 
 The Neuralink N1 implant generates approximately 200Mbps of electrode data (1024 electrodes @ 20kHz, 10-bit resolution) and can transmit data wirelessly at about 1Mbps. This means a compression ratio of over 200x is required. The compression must run in real-time (< 1ms) and consume low power (< 10mW, including radio).
 
+## Data Analysis
+
+The `analysis.ipynb` notebook contains a detailed analysis of the data. We found that there is sometimes significant cross-correlation between the different leads, so we find it vital to use this information for better compression. This cross-correlation allows us to improve the accuracy of our predictions and reduce the overall amount of data that needs to be transmitted. As part of the analysis, we also note that achieving a 200x compression ratio is highly unlikely to be possible and arguably nonsensical; a very close reproduction is sufficient.
+
+## Compression Details
+
+The solution leverages three neural network models to achieve effective compression:
+
+1. **Latent Projector**: This module takes in a segment of a lead and projects it into a latent space. The latent projector can be configured as a fully connected network or an RNN (LSTM), depending on the configuration.
+
+2. **MiddleOut (Message Passer)**: For each lead, this module looks up the `n` most correlated leads and uses their latent representations along with their correlation values to generate a new latent representation. This is done by training a fully connected layer to map (our_latent, their_latent, correlation) -> new_latent and then averaging over all new_latent values to get the final representation.
+
+3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training.
+
+By accurately predicting the next timestep, the delta (difference) between the actual value and the predicted value is minimized. Small deltas mean that fewer bits are needed to store these values, which are then efficiently encoded by the bitstream encoder.
+
+The neural networks used in this solution are tiny, making it possible to meet the latency and power requirements if implemented more efficiently.
+
 ## Installation
 
 To install the necessary dependencies, create a virtual environment and install the requirements:
@@ -18,18 +36,9 @@ pip install -r requirements.txt
 
 ## Usage
 
-### Running the Code
-
-To train the model and compress/decompress WAV files, use the CLI provided:
-
-```bash
-python cli.py compress --config config.yaml --input_file path/to/input.wav --output_file path/to/output.bin
-python cli.py decompress --config config.yaml --input_file path/to/output.bin --output_file path/to/output.wav
-```
-
 ### Training
 
-Requires Slate, which is not currently publicaly avaible. Install via (requires repo access)
+Requires Slate, which is not currently publicly available. Install via (requires repo access):
 
 ```bash
 pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
@@ -38,5 +47,6 @@ pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
 To train the model, run:
 
 ```bash
-python main.py config.yaml Test
+python main.py <config_file.yaml> <exp_name>
 ```
+
````
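To make the README's delta argument concrete, here is a minimal sketch (not from the repository) of how small residuals from a good predictor need fewer bits than raw 10-bit samples:

```python
import numpy as np

signal = np.array([512, 515, 519, 522, 524], dtype=np.int32)      # raw 10-bit samples
prediction = np.array([512, 514, 518, 523, 525], dtype=np.int32)  # hypothetical model output
deltas = signal - prediction                                      # [0, 1, 1, -1, -1]

# Bits needed per value: raw sample vs. signed residual (sign bit + magnitude).
bits_raw = 10
bits_delta = int(np.max(np.ceil(np.log2(np.abs(deltas) + 1)) + 1))
print(bits_raw, bits_delta)  # 10 vs. 2
```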
analysis.ipynb | 1285119 (new file)

File diff suppressed because one or more lines are too long
cli.py | 50 (deleted)

```diff
@@ -1,50 +0,0 @@
-import argparse
-import yaml
-import os
-import torch
-from data_processing import download_and_extract_data, load_all_wavs, save_wav, delta_encode, delta_decode
-from main import SpikeRunner
-
-def load_config(config_path):
-    with open(config_path, 'r') as file:
-        config = yaml.safe_load(file)
-    return config
-
-def main():
-    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
-    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
-    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
-    parser.add_argument('--input_file', help="Path to the input WAV file")
-    parser.add_argument('--output_file', help="Path to the output file")
-    args = parser.parse_args()
-
-    config = load_config(args.config)
-
-    spike_runner = SpikeRunner(None, config)
-    spike_runner.setup('CLI')
-
-    if args.action == 'compress':
-        data = load_all_wavs(args.input_file)
-        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
-            data = [delta_encode(d) for d in data]
-
-        spike_runner.encoder.build_model(data)
-        encoded_data = [spike_runner.model(torch.tensor(d, dtype=torch.float32).unsqueeze(0)).squeeze(0).detach().numpy().tolist() for d in data]
-        compressed_data = [spike_runner.encoder.encode(ed) for ed in encoded_data]
-
-        with open(args.output_file, 'wb') as f:
-            for cd in compressed_data:
-                f.write(bytearray(cd))
-
-    elif args.action == 'decompress':
-        with open(args.input_file, 'rb') as f:
-            compressed_data = list(f.read())
-
-        decoded_data = [spike_runner.encoder.decode(cd, len(cd)) for cd in compressed_data]
-        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
-            decoded_data = [delta_decode(dd) for dd in decoded_data]
-
-        save_wav(args.output_file, 19531, decoded_data)  # Assuming 19531 Hz sample rate
-
-if __name__ == "__main__":
-    main()
```
config.yaml | 35

```diff
@@ -41,15 +41,22 @@ wandb:
 name: Test
 import: $
 
-preprocessing:
-  use_delta_encoding: false # Whether to use delta encoding.
+latent_projector:
+  type: fc # Options: 'fc', 'rnn'
+  input_size: 50 # Input size for the Latent Projector (length of snippets).
+  latent_size: 8 # Size of the latent representation before message passing.
+  layer_shapes: [16, 32] # List of layer sizes for the latent projector (if type is 'fc').
+  activations: ['relu', 'relu'] # Activation functions for the latent projector layers (if type is 'fc').
+  rnn_hidden_size: 32 # Hidden size for the RNN projector (if type is 'rnn').
+  rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
+
+middle_out:
+  output_size: 16 # Size of the latent representation after message passing.
+  num_peers: 3 # Number of most correlated peers to consider.
 
 predictor:
-  type: lstm # Options: 'lstm', 'fixed_input_nn'
-  input_size: 1 # Input size for the LSTM predictor.
-  hidden_size: 8 # 16 # Hidden size for the LSTM or Fixed Input NN predictor.
-  num_layers: 2 # Number of layers for the LSTM predictor.
-  fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.
+  layer_shapes: [32, 16] # List of layer sizes for the predictor.
+  activations: ['relu', 'relu'] # Activation functions for the predictor layers.
 
 training:
   epochs: 128 # Number of training epochs.
@@ -59,6 +66,9 @@ training:
   save_path: models # Directory to save the best model and encoder.
   num_points: 1000 # Number of data points to visualize
 
+evaluation:
+  full_compression: false # Perform full compression during evaluation
+
 bitstream_encoding:
   type: identity # Options: 'arithmetic', 'no_compression', 'bzip2'
 
@@ -66,14 +76,7 @@ data:
   url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
   directory: data # Directory to extract and store the dataset.
   split_ratio: 0.8 # Ratio to split the data into train and test sets.
+  cut_length: None # Optional length to cut sequences to.
 
 profiler:
   enable: false
-
-ablative:
-  training:
-    learning_rate: [0.01, 0.0001, 0.00001]
-    batch_size: [4, 16]
-  predictor:
-    hidden_size: [4, 16]
-    num_layers: [1, 3]
```
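One caveat on the new `cut_length: None` entry: in YAML, a bare `None` parses as the string `'None'`, not as a null (YAML's nulls are `null`, `~`, or an empty value), so unless Slate normalizes it, a truthiness check like the `if cut_length:` in `load_all_wavs` would see a truthy string. A quick check with PyYAML:

```python
import yaml

# Bare 'None' is a plain string in YAML; only null/~/empty parse to Python None.
print(yaml.safe_load('cut_length: None'))  # {'cut_length': 'None'}
print(yaml.safe_load('cut_length: null'))  # {'cut_length': None}
```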
```diff
@@ -14,33 +14,33 @@ def download_and_extract_data(url, data_dir):
     os.remove(zip_path)
 
 def load_wav(file_path):
-    """Load WAV file and return sample rate and data."""
     sample_rate, data = wavfile.read(file_path)
     return sample_rate, data
 
-def load_all_wavs(data_dir):
-    """Load all WAV files in the given directory."""
+def load_all_wavs(data_dir, cut_length=None):
     wav_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
     all_data = []
     for file_path in wav_files:
         _, data = load_wav(file_path)
+        if cut_length:
+            data = data[:cut_length]
         all_data.append(data)
     return all_data
 
-def save_wav(file_path, sample_rate, data):
-    """Save data to a WAV file."""
-    wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))
+def compute_correlation_matrix(data):
+    num_leads = len(data)
+    corr_matrix = np.zeros((num_leads, num_leads))
+    for i in range(num_leads):
+        for j in range(num_leads):
+            if i != j:
+                corr_matrix[i, j] = np.corrcoef(data[i], data[j])[0, 1]
+    return corr_matrix
 
-def delta_encode(data):
-    """Apply delta encoding to the data."""
-    deltas = [data[0]]
-    for i in range(1, len(data)):
-        deltas.append(data[i] - data[i - 1])
-    return np.array(deltas)
-
-def delta_decode(deltas):
-    """Decode delta encoded data."""
-    data = [deltas[0]]
-    for i in range(1, len(deltas)):
-        data.append(data[-1] + deltas[i])
-    return np.array(data)
+def split_data_by_time(data, split_ratio=0.5):
+    train_data = []
+    test_data = []
+    for lead in data:
+        split_idx = int(len(lead) * split_ratio)
+        train_data.append(lead[:split_idx])
+        test_data.append(lead[split_idx:])
+    return train_data, test_data
```
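The nested-loop `compute_correlation_matrix` above makes O(n²) separate `np.corrcoef` calls. For equal-length leads, the whole matrix comes from a single call; a minimal equivalent sketch (not part of the commit):

```python
import numpy as np

def compute_correlation_matrix_vectorized(data):
    # np.corrcoef on a (num_leads, num_samples) array yields the full
    # lead-by-lead correlation matrix in one pass.
    corr = np.corrcoef(np.stack(data))
    np.fill_diagonal(corr, 0.0)  # the loop version leaves i == j entries at 0
    return corr
```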
main.py | 255

```diff
@@ -1,77 +1,71 @@
-from slate import Slate, Slate_Runner
-
-from pycallgraph2 import PyCallGraph
-from pycallgraph2.output import GraphvizOutput
-
-from data_processing import download_and_extract_data, load_all_wavs, delta_encode
-from model import LSTMPredictor, FixedInputNNPredictor
-from train import train_model
-from bitstream import ArithmeticEncoder, IdentityEncoder, Bzip2Encoder
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import random
+from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution
+from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
+from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
+import wandb
+from pycallgraph import PyCallGraph
+from pycallgraph.output import GraphvizOutput
+import slate
 
-class SpikeRunner(Slate_Runner):
-    def setup(self, name):
-        self.name = name
-        slate, config = self.slate, self.config
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-        # Consume config sections
-        preprocessing_config = slate.consume(config, 'preprocessing', expand=True)
-        predictor_config = slate.consume(config, 'predictor', expand=True)
+class SpikeRunner:
+    def __init__(self, config):
+        self.config = config
+        self.name = slate.consume(config, 'name', default='Test')
+
         training_config = slate.consume(config, 'training', expand=True)
-        bitstream_config = slate.consume(config, 'bitstream_encoding', expand=True)
         data_config = slate.consume(config, 'data', expand=True)
 
         # Data setup
         data_url = slate.consume(data_config, 'url')
         data_dir = slate.consume(data_config, 'directory')
+        cut_length = slate.consume(data_config, 'cut_length', None)
         download_and_extract_data(data_url, data_dir)
-        all_data = load_all_wavs(data_dir)
+        all_data = load_all_wavs(data_dir, cut_length)
+
+        split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
+        self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
+
+        # Compute correlation matrix
+        self.correlation_matrix = compute_correlation_matrix(self.train_data)
+
+        # Model setup
+        latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
+        if latent_projector_type == 'fc':
+            self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
+        elif latent_projector_type == 'rnn':
+            self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
+
+        self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device)
+        self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device)
 
+        # Training parameters
         self.epochs = slate.consume(training_config, 'epochs')
         self.batch_size = slate.consume(training_config, 'batch_size')
         self.learning_rate = slate.consume(training_config, 'learning_rate')
-        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
-        self.save_path = slate.consume(training_config, 'save_path', 'models')
+        self.save_path = slate.consume(training_config, 'save_path')
 
-        if self.use_delta_encoding:
-            all_data = [delta_encode(d) for d in all_data]
+        # Evaluation parameter
+        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
 
-        # Split data into train and test sets
-        split_ratio = slate.consume(data_config, 'split_ratio', 0.8)
-        split_idx = int(len(all_data) * split_ratio)
-        self.train_data = all_data[:split_idx]
-        self.test_data = all_data[split_idx:]
+        # Bitstream encoding
+        bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
+        if bitstream_type == 'identity':
+            self.encoder = IdentityEncoder()
+        elif bitstream_type == 'arithmetic':
+            self.encoder = ArithmeticEncoder()
+        elif bitstream_type == 'bzip2':
+            self.encoder = Bzip2Encoder()
 
-        # Model setup
-        self.model = self.get_model(predictor_config)
-        self.encoder = self.get_encoder(bitstream_config)
-
-    def get_model(self, config):
-        model_type = slate.consume(config, 'type')
-        if model_type == 'lstm':
-            return LSTMPredictor(
-                input_size=slate.consume(config, 'input_size'),
-                hidden_size=slate.consume(config, 'hidden_size'),
-                num_layers=slate.consume(config, 'num_layers')
-            )
-        elif model_type == 'fixed_input_nn':
-            return FixedInputNNPredictor(
-                input_size=slate.consume(config, 'fixed_input_size'),
-                hidden_size=slate.consume(config, 'hidden_size')
-            )
-        else:
-            raise ValueError(f"Unknown model type: {model_type}")
-
-    def get_encoder(self, config):
-        encoder_type = slate.consume(config, 'type')
-        if encoder_type == 'arithmetic':
-            return ArithmeticEncoder()
-        elif encoder_type == 'identity':
-            return IdentityEncoder()
-        elif encoder_type == 'bzip2':
-            return Bzip2Encoder()
-        else:
-            raise ValueError(f"Unknown encoder type: {encoder_type}")
+        # Optimizer
+        self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
+        self.criterion = torch.nn.MSELoss()
 
     def run(self, run, forceNoProfile=False):
         if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
@@ -80,12 +74,147 @@ class SpikeRunner(Slate_Runner):
             self.run(run, forceNoProfile=True)
             print('{PROFILER DONE}')
             return
 
-        train_model(
-            self.model, self.train_data, self.test_data,
-            self.epochs, self.batch_size, self.learning_rate,
-            self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path
-        )
+        self.train_model()
+
+    def train_model(self):
+        max_length = max([len(seq) for seq in self.train_data])
+        print(f"Max sequence length: {max_length}")
+
+        best_test_score = float('inf')
+
+        for epoch in range(self.epochs):
+            total_loss = 0
+            random.shuffle(self.train_data)
+            for i in range(0, len(self.train_data[0]) - self.input_size, self.input_size):
+                batch_data = np.array([lead[i:i+self.input_size] for lead in self.train_data])
+                inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(device)
+
+                batch_loss = 0
+                for lead_idx in range(len(inputs)):
+                    lead_data = inputs[lead_idx]
+                    latents = self.projector(lead_data)
+
+                    for t in range(latents.shape[0]):
+                        my_latent = latents[t]
+
+                        peer_latents = []
+                        peer_correlations = []
+                        for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
+                            peer_latent = latents[t]
+                            peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
+                            peer_latents.append(peer_latent)
+                            peer_correlations.append(peer_correlation)
+
+                        peer_latents = torch.stack(peer_latents).to(device)
+                        peer_correlations = torch.stack(peer_correlations).to(device)
+                        new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
+                        prediction = self.predictor(new_latent)
+                        target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
+
+                        loss = self.criterion(prediction, target)
+                        batch_loss += loss.item()
+                        self.optimizer.zero_grad()
+                        loss.backward()
+                        self.optimizer.step()
+
+                total_loss += batch_loss
+
+            wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
+            print(f'Epoch {epoch+1}/{self.epochs}, Loss: {total_loss}')
+
+            if (epoch + 1) % self.eval_freq == 0:
+                test_loss = self.evaluate_model(epoch)
+                if test_loss < best_test_score:
+                    best_test_score = test_loss
+                    self.save_models(epoch)
+
+    def evaluate_model(self, epoch):
+        self.projector.eval()
+        self.middle_out.eval()
+        self.predictor.eval()
+
+        total_loss = 0
+        all_true = []
+        all_predicted = []
+        all_deltas = []
+        compression_ratios = []
+        exact_matches = 0
+        total_sequences = 0
+
+        with torch.no_grad():
+            for lead_idx in range(len(self.test_data)):
+                lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device)
+                latents = self.projector(lead_data)
+
+                true_data = []
+                predicted_data = []
+                delta_data = []
+
+                for t in range(latents.shape[0]):
+                    my_latent = latents[t]
+
+                    peer_latents = []
+                    peer_correlations = []
+                    for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
+                        peer_latent = latents[t]
+                        peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
+                        peer_latents.append(peer_latent)
+                        peer_correlations.append(peer_correlation)
+
+                    peer_latents = torch.stack(peer_latents).to(device)
+                    peer_correlations = torch.stack(peer_correlations).to(device)
+                    new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
+                    prediction = self.predictor(new_latent)
+                    target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
+
+                    loss = self.criterion(prediction, target)
+                    total_loss += loss.item()
+
+                    true_data.append(target.cpu().numpy())
+                    predicted_data.append(prediction.cpu().numpy())
+                    delta_data.append((target - prediction).cpu().numpy())
+
+                all_true.append(true_data)
+                all_predicted.append(predicted_data)
+                all_deltas.append(delta_data)
+
+                if self.full_compression:
+                    self.encoder.build_model(latents.cpu().numpy())
+                    compressed_data = self.encoder.encode(latents.cpu().numpy())
+                    decompressed_data = self.encoder.decode(compressed_data, len(latents))
+                    compression_ratio = len(latents) / len(compressed_data)
+                    compression_ratios.append(compression_ratio)
+
+                    # Check if decompressed data matches the original data
+                    if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5):
+                        exact_matches += 1
+                    total_sequences += 1
+
+                visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch)
+
+        avg_loss = total_loss / len(self.test_data)
+        print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
+        wandb.log({"evaluation_loss": avg_loss}, step=epoch)
+
+        delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
+        wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
+
+        if self.full_compression:
+            avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
+            exact_match_percentage = (exact_matches / total_sequences) * 100
+            print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
+            print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
+            wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
+            wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
+
+        return avg_loss
+
+    def save_models(self, epoch):
+        torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
+        torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
+        torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
+        print(f"New high score! Models saved at epoch {epoch+1}.")
 
 if __name__ == '__main__':
     slate = Slate({'spikey': SpikeRunner})
```
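bitstream.py itself is not part of this commit's diff; from the call sites above, its encoders expose `build_model(data)`, `encode(data)`, and `decode(compressed, length)`. A hypothetical sketch of what the `Bzip2Encoder` variant could look like under that assumed interface (not the repository's actual implementation):

```python
import bz2
import numpy as np

class Bzip2Encoder:
    """Hypothetical sketch matching the build_model/encode/decode call sites above."""

    def build_model(self, data):
        # bzip2 needs no trained model; kept only for interface compatibility.
        pass

    def encode(self, data):
        # Serialize to float32 bytes, then compress with the stdlib bz2 codec.
        return bz2.compress(np.asarray(data, dtype=np.float32).tobytes())

    def decode(self, compressed, length):
        # Decompress and return the first `length` float32 values.
        raw = np.frombuffer(bz2.decompress(compressed), dtype=np.float32)
        return raw[:length]
```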
model.py | 110 (deleted)

```diff
@@ -1,110 +0,0 @@
-import torch
-import torch.nn as nn
-from abc import ABC, abstractmethod
-
-class BaseModel(nn.Module):
-    def __init__(self):
-        super(BaseModel, self).__init__()
-
-    @abstractmethod
-    def forward(self, x):
-        pass
-
-    @abstractmethod
-    def encode(self, data):
-        pass
-
-    @abstractmethod
-    def decode(self, encoded_data):
-        pass
-
-class LSTMPredictor(BaseModel):
-    def __init__(self, input_size, hidden_size, num_layers):
-        super(LSTMPredictor, self).__init__()
-        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
-        self.fc = nn.Linear(hidden_size, 1)
-        self.hidden_size = hidden_size
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-    def forward(self, x):
-        h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
-        c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
-        out, _ = self.rnn(x, (h0, c0))
-        out = self.fc(out)
-        return out
-
-    def encode(self, data):
-        self.eval()
-        encoded_data = []
-
-        context_size = self.hidden_size  # Define an appropriate context size
-        with torch.no_grad():
-            for i in range(len(data) - 1):
-                context = torch.tensor(data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
-                if context.size(1) == 0:  # Handle empty context
-                    continue
-                prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
-                delta = data[i] - prediction
-                encoded_data.append(delta)
-
-        return encoded_data
-
-    def decode(self, encoded_data):
-        self.eval()
-        decoded_data = []
-
-        context_size = self.hidden_size  # Define an appropriate context size
-        with torch.no_grad():
-            for i in range(len(encoded_data)):
-                context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
-                if context.size(1) == 0:  # Handle empty context
-                    continue
-                prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
-                decoded_data.append(prediction + encoded_data[i])
-
-        return decoded_data
-
-class FixedInputNNPredictor(BaseModel):
-    def __init__(self, input_size, hidden_size):
-        super(FixedInputNNPredictor, self).__init__()
-        self.fc1 = nn.Linear(input_size, hidden_size)
-        self.relu = nn.ReLU()
-        self.fc2 = nn.Linear(hidden_size, 1)
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        return x
-
-    def encode(self, data):
-        self.eval()
-        encoded_data = []
-
-        context_size = self.fc1.in_features  # Define an appropriate context size
-        with torch.no_grad():
-            for i in range(len(data) - context_size):
-                context = torch.tensor(data[i:i + context_size]).reshape(1, -1).to(self.device)
-                if context.size(1) == 0:  # Handle empty context
-                    continue
-                prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
-                delta = data[i + context_size] - prediction
-                encoded_data.append(delta)
-
-        return encoded_data
-
-    def decode(self, encoded_data):
-        self.eval()
-        decoded_data = []
-
-        context_size = self.fc1.in_features  # Define an appropriate context size
-        with torch.no_grad():
-            for i in range(len(encoded_data)):
-                context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1).to(self.device)
-                if context.size(1) == 0:  # Handle empty context
-                    continue
-                prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
-                decoded_data.append(prediction + encoded_data[i])
-
-        return decoded_data
```
models.py | 75 (new file)

```diff
@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+
+def get_activation(name):
+    activations = {
+        'ReLU': nn.ReLU,
+        'Sigmoid': nn.Sigmoid,
+        'Tanh': nn.Tanh,
+        'LeakyReLU': nn.LeakyReLU,
+        'ELU': nn.ELU,
+        'None': nn.Identity
+    }
+    return activations[name]()
+
+class LatentProjector(nn.Module):
+    def __init__(self, input_size, latent_size, layer_shapes, activations):
+        super(LatentProjector, self).__init__()
+        layers = []
+        in_features = input_size
+        for i, out_features in enumerate(layer_shapes):
+            layers.append(nn.Linear(in_features, out_features))
+            if activations[i] != 'None':
+                layers.append(get_activation(activations[i]))
+            in_features = out_features
+        layers.append(nn.Linear(in_features, latent_size))
+        self.fc = nn.Sequential(*layers)
+        self.latent_size = latent_size
+
+    def forward(self, x):
+        return self.fc(x)
+
+class LatentRNNProjector(nn.Module):
+    def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
+        super(LatentRNNProjector, self).__init__()
+        self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
+        self.fc = nn.Linear(rnn_hidden_size, latent_size)
+        self.latent_size = latent_size
+
+    def forward(self, x):
+        out, _ = self.rnn(x)
+        latent = self.fc(out)
+        return latent
+
+class MiddleOut(nn.Module):
+    def __init__(self, latent_size, output_size, num_peers):
+        super(MiddleOut, self).__init__()
+        self.num_peers = num_peers
+        self.fc = nn.Linear(latent_size * 2 + 1, output_size)
+
+    def forward(self, my_latent, peer_latents, peer_correlations):
+        new_latents = []
+        for peer_latent, correlation in zip(peer_latents, peer_correlations):
+            combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1)
+            new_latent = self.fc(combined_input)
+            new_latents.append(new_latent)
+
+        new_latents = torch.stack(new_latents)
+        averaged_latent = torch.mean(new_latents, dim=0)
+        return averaged_latent
+
+class Predictor(nn.Module):
+    def __init__(self, output_size, layer_shapes, activations):
+        super(Predictor, self).__init__()
+        layers = []
+        in_features = output_size
+        for i, out_features in enumerate(layer_shapes):
+            layers.append(nn.Linear(in_features, out_features))
+            if activations[i] != 'None':
+                layers.append(get_activation(activations[i]))
+            in_features = out_features
+        layers.append(nn.Linear(in_features, 1))
+        self.fc = nn.Sequential(*layers)
+
+    def forward(self, latent):
+        return self.fc(latent)
```
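For orientation, a minimal sketch (not part of the commit) of how the three modules above compose on dummy tensors, using the sizes from config.yaml. Note that `get_activation`'s keys are capitalized ('ReLU'), while config.yaml passes 'relu'; the sketch uses the keys the code actually accepts:

```python
import torch
from models import LatentProjector, MiddleOut, Predictor

# Sizes taken from config.yaml; the tensor wiring is illustrative only.
projector = LatentProjector(input_size=50, latent_size=8, layer_shapes=[16, 32], activations=['ReLU', 'ReLU'])
middle_out = MiddleOut(latent_size=8, output_size=16, num_peers=3)
predictor = Predictor(output_size=16, layer_shapes=[32, 16], activations=['ReLU', 'ReLU'])

snippet = torch.randn(50)                 # one 50-sample window from a lead
my_latent = projector(snippet)            # -> (8,)
peer_latents = torch.randn(3, 8)          # latents of the 3 most correlated peers
peer_corrs = torch.randn(3, 1)            # their correlation values
fused = middle_out(my_latent, peer_latents, peer_corrs)   # -> (16,)
next_sample = predictor(fused)            # -> (1,), prediction for the next timestep
```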
train.py | 121 (deleted)

```diff
@@ -1,121 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import wandb
-import random
-import os
-import pickle
-from data_processing import delta_encode, delta_decode, save_wav
-from utils import visualize_prediction, plot_delta_distribution
-from bitstream import ArithmeticEncoder
-
-def pad_sequence(sequence, max_length):
-    padded_seq = np.zeros((max_length, *sequence.shape[1:]))
-    padded_seq[:sequence.shape[0], ...] = sequence
-    return padded_seq
-
-def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0):
-    compression_ratios = []
-    identical_count = 0
-    all_deltas = []
-
-    for i, file_data in enumerate(data):
-        file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(model.device)
-        encoded_data = model.encode(file_data.squeeze(1).cpu().numpy())
-        encoder.build_model(encoded_data)
-        compressed_data = encoder.encode(encoded_data)
-        decompressed_data = encoder.decode(compressed_data, len(encoded_data))
-
-        if use_delta_encoding:
-            decompressed_data = delta_decode(decompressed_data)
-
-        # Ensure the lengths match
-        min_length = min(len(file_data), len(decompressed_data))
-        file_data = file_data[:min_length]
-        decompressed_data = decompressed_data[:min_length]
-
-        identical = np.allclose(file_data.cpu().numpy(), decompressed_data, atol=1e-5)
-        if identical:
-            identical_count += 1
-
-        compression_ratio = len(file_data) / len(compressed_data)
-        compression_ratios.append(compression_ratio)
-
-        predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(model.device)).squeeze(1).detach().cpu().numpy()
-        if use_delta_encoding:
-            predicted_data = delta_decode(predicted_data)
-
-        # Ensure predicted_data is a flat list of floats
-        predicted_data = predicted_data[:min_length]
-
-        delta_data = [file_data[i].item() - predicted_data[i] for i in range(min_length)]
-        all_deltas.extend(delta_data)
-
-        if i == (epoch % len(data)):
-            visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate, epoch=epoch)
-
-    identical_percentage = (identical_count / len(data)) * 100
-    delta_plot_path = plot_delta_distribution(all_deltas, epoch)
-    wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
-
-    return compression_ratios, identical_percentage
-
-def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path):
-    wandb.init(project="wav-compression")
-    criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    best_test_score = float('inf')
-
-    model.to(model.device)
-
-    max_length = max([len(seq) for seq in train_data])
-    print(f"Max sequence length: {max_length}")
-
-    for epoch in range(epochs):
-        total_loss = 0
-        random.shuffle(train_data)
-        for i in range(0, len(train_data) - batch_size, batch_size):
-            batch_data = [pad_sequence(np.array(train_data[j]), max_length) for j in range(i, i+batch_size)]
-            batch_data = np.array(batch_data)
-            inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
-            targets = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
-            outputs = model(inputs)
-            loss = criterion(outputs, targets)
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            total_loss += loss.item()
-
-        wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
-        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')
-
-        if (epoch + 1) % eval_freq == 0:
-            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch)
-            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch)
-
-            wandb.log({
-                "train_compression_ratio_mean": np.mean(train_compression_ratios),
-                "train_compression_ratio_std": np.std(train_compression_ratios),
-                "train_compression_ratio_min": np.min(train_compression_ratios),
-                "train_compression_ratio_max": np.max(train_compression_ratios),
-                "test_compression_ratio_mean": np.mean(test_compression_ratios),
-                "test_compression_ratio_std": np.std(test_compression_ratios),
-                "test_compression_ratio_min": np.min(test_compression_ratios),
-                "test_compression_ratio_max": np.max(test_compression_ratios),
-                "train_identical_percentage": train_identical_percentage,
-                "test_identical_percentage": test_identical_percentage,
-            }, step=epoch)
-
-            print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
-            print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')
-
-            test_score = np.mean(test_compression_ratios)
-            if test_score < best_test_score:
-                best_test_score = test_score
-                model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
-                encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
-                torch.save(model.state_dict(), model_path)
-                with open(encoder_path, 'wb') as f:
-                    pickle.dump(encoder, f)
-                print(f'New highscore on test data! Model and encoder saved to {model_path} and {encoder_path}')
```