initial commit

Dominik Moritz Roth 2024-05-24 22:01:59 +02:00
commit 0c5f888d75
11 changed files with 633 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1,10 @@
data
data.zip
__pycache__
.venv
wandb
slurm_log
job_hist.log
models
Xvfb.log
profiler

README.md Normal file

@@ -0,0 +1,76 @@
# Spikey
This repository contains a solution for the [Neuralink Compression Challenge](https://content.neuralink.com/compression-challenge/README.html). The challenge involves compressing raw electrode recordings from a Neuralink implant, taken from the motor cortex of a non-human primate while it plays a video game.
## Challenge Overview
The Neuralink N1 implant generates approximately 200 Mbps of electrode data (1024 electrodes @ 20 kHz, 10-bit resolution) but can only transmit about 1 Mbps wirelessly, so a compression ratio of over 200x is required. Compression must run in real time (< 1 ms) at low power (< 10 mW, including the radio).
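As a sanity check, the required ratio follows directly from these numbers (a back-of-the-envelope sketch, not part of the pipeline):
```python
electrodes = 1024
sample_rate_hz = 20_000
bits_per_sample = 10

raw_bps = electrodes * sample_rate_hz * bits_per_sample  # 204,800,000 b/s ~ 204.8 Mbps
radio_bps = 1_000_000                                    # ~ 1 Mbps wireless budget
print(f"required compression ratio: {raw_bps / radio_bps:.0f}x")  # ~205x
```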
## Installation
To install the necessary dependencies, create a virtual environment and install the requirements:
```bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
## Usage
### Configuration
The configuration for training and evaluation is specified in a YAML file. Below is an example configuration:
```yaml
name: Test

preprocessing:
  use_delta_encoding: true  # Whether to use delta encoding.

predictor:
  type: lstm  # Options: 'lstm', 'fixed_input_nn'
  input_size: 1  # Input size for the LSTM predictor.
  hidden_size: 128  # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2  # Number of layers for the LSTM predictor.
  fixed_input_size: 10  # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10  # Number of training epochs.
  batch_size: 32  # Batch size for training.
  learning_rate: 0.001  # Learning rate for the optimizer.
  eval_freq: 2  # Frequency of evaluation during training (in epochs).
  save_path: models  # Directory to save the best model and encoder.
  num_points: 1000  # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic  # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
  directory: data  # Directory to extract and store the dataset.
  split_ratio: 0.8  # Ratio to split the data into train and test sets.
```
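The CLI reads this file with `yaml.safe_load` (see `load_config` in `cli.py`); for a single-document config like the example above, access is plain dictionary lookup:
```python
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

print(config["predictor"]["type"])       # 'lstm'
print(config["training"]["batch_size"])  # 32
```
Note that the repository's own `config.yaml` contains multiple YAML documents separated by `---` and is handled by Slate rather than by this snippet.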
### Running the Code
To train the model and compress/decompress WAV files, use the CLI provided:
```bash
python cli.py compress --config config.yaml --input_file path/to/input.wav --output_file path/to/output.bin
python cli.py decompress --config config.yaml --input_file path/to/output.bin --output_file path/to/output.wav
```
### Training
Training requires Slate, which is not currently publicly available. Install it via (repo access required):
```bash
pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
```
To train the model, run:
```bash
python main.py config.yaml Test
```

bitstream.py Normal file

@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod

from arithmetic_compressor import AECompressor
from arithmetic_compressor.models import StaticModel


class BaseEncoder(ABC):
    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data, num_symbols):
        pass

    @abstractmethod
    def build_model(self, data):
        pass


class ArithmeticEncoder(BaseEncoder):
    def encode(self, data):
        if not hasattr(self, 'model'):
            raise ValueError("Model not built. Call build_model(data) before encoding.")
        coder = AECompressor(self.model)
        return coder.compress(data)

    def decode(self, encoded_data, num_symbols):
        if not hasattr(self, 'model'):
            raise ValueError("Model not built. Call build_model(data) before decoding.")
        coder = AECompressor(self.model)
        return coder.decompress(encoded_data, num_symbols)

    def build_model(self, data):
        # Static source model: per-symbol probabilities from empirical counts.
        symbol_counts = {symbol: data.count(symbol) for symbol in set(data)}
        total_symbols = sum(symbol_counts.values())
        probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
        self.model = StaticModel(probabilities)
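A minimal round-trip sketch for `ArithmeticEncoder` (assuming a small integer alphabet; names here are illustrative):
```python
from bitstream import ArithmeticEncoder

encoder = ArithmeticEncoder()
symbols = [0, 1, 1, 2, 0, 1, 0, 0]  # toy symbol stream

encoder.build_model(symbols)        # static probabilities from counts
bits = encoder.encode(symbols)      # list of bits from AECompressor
restored = encoder.decode(bits, num_symbols=len(symbols))
assert restored == symbols
```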

cli.py Normal file

@@ -0,0 +1,50 @@
import argparse

import yaml

from data_processing import load_wav, save_wav, delta_encode, delta_decode
from main import SpikeRunner


def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


def main():
    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
    parser.add_argument('--input_file', required=True, help="Path to the input file")
    parser.add_argument('--output_file', required=True, help="Path to the output file")
    args = parser.parse_args()

    config = load_config(args.config)
    spike_runner = SpikeRunner(None, config)
    spike_runner.setup('CLI')

    if args.action == 'compress':
        # Compress a single WAV file (load_all_wavs expects a directory).
        _, data = load_wav(args.input_file)
        data = data.tolist()
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            data = delta_encode(data)
        residuals = spike_runner.model.encode(data)
        spike_runner.encoder.build_model(residuals)
        compressed_data = spike_runner.encoder.encode(residuals)
        with open(args.output_file, 'wb') as f:
            f.write(bytearray(compressed_data))  # one byte per bit (see framing note below)

    elif args.action == 'decompress':
        # NOTE: assumes one bitstream per file, symbol count == bit count, and an
        # encoder whose symbol model matches compression (e.g. restored from the
        # pickled encoder saved by train.py); a real container format would store
        # these lengths and the model explicitly.
        with open(args.input_file, 'rb') as f:
            compressed_data = list(f.read())
        residuals = spike_runner.encoder.decode(compressed_data, len(compressed_data))
        data = spike_runner.model.decode(residuals)
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            data = delta_decode(data)
        save_wav(args.output_file, 19531, data)  # challenge data is sampled at 19531 Hz


if __name__ == "__main__":
    main()
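The compress/decompress paths above assume one bitstream per file and recover the symbol count from the bit count. A hypothetical length-prefixed framing that would make multi-file containers self-describing (names and layout are illustrative, not part of this commit):
```python
import struct

def write_frames(path, bitstreams):
    """Write each bitstream as [bit_count: uint32][bits: one byte per bit]."""
    with open(path, 'wb') as f:
        for bits in bitstreams:
            f.write(struct.pack('<I', len(bits)))
            f.write(bytes(bits))  # wasteful (1 bit per byte) but simple

def read_frames(path):
    """Read back the frames written by write_frames."""
    bitstreams = []
    with open(path, 'rb') as f:
        while header := f.read(4):
            (n,) = struct.unpack('<I', header)
            bitstreams.append(list(f.read(n)))
    return bitstreams
```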

config.yaml Normal file

@@ -0,0 +1,67 @@
name: DEFAULT
project: Spikey

slurm:
  name: 'Spikey_{config[name]}'
  partitions:
    - single
  standard_output: ./reports/slurm/out_%A_%a.log
  standard_error: ./reports/slurm/err_%A_%a.log
  num_parallel_jobs: 50
  cpus_per_task: 4
  memory_per_cpu: 1000
  time_limit: 1440  # in minutes
  ntasks: 1
  venv: '.venv/bin/activate'
  sh_lines:
    - 'mkdir -p {tmp}/wandb'
    - 'mkdir -p {tmp}/local_pycache'
    - 'export PYTHONPYCACHEPREFIX={tmp}/local_pycache'

runner: spikey

scheduler:
  reps_per_version: 1
  agents_per_job: 1
  reps_per_agent: 1

wandb:
  project: '{config[project]}'
  group: '{config[name]}'
  job_type: '{delta_desc}'
  name: '{job_id}_{task_id}:{run_id}:{rand}={config[name]}_{delta_desc}'
  tags:
    - '{config[env][name]}'
    - '{config[algo][name]}'
  sync_tensorboard: False
  monitor_gym: False
  save_code: False

---
name: Test

preprocessing:
  use_delta_encoding: true  # Whether to use delta encoding.

predictor:
  type: lstm  # Options: 'lstm', 'fixed_input_nn'
  input_size: 1  # Input size for the LSTM predictor.
  hidden_size: 128  # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2  # Number of layers for the LSTM predictor.
  fixed_input_size: 10  # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10  # Number of training epochs.
  batch_size: 32  # Batch size for training.
  learning_rate: 0.001  # Learning rate for the optimizer.
  eval_freq: 2  # Frequency of evaluation during training (in epochs).
  save_path: models  # Directory to save the best model and encoder.
  num_points: 1000  # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic  # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
  directory: data  # Directory to extract and store the dataset.
  split_ratio: 0.8  # Ratio to split the data into train and test sets.

data_processing.py Normal file

@@ -0,0 +1,46 @@
import os
import urllib.request
import zipfile

import numpy as np
from scipy.io import wavfile


def download_and_extract_data(url, data_dir):
    """Download and unpack the dataset, skipping the download if data_dir exists."""
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
        zip_path = os.path.join(data_dir, 'data.zip')
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        os.remove(zip_path)


def load_wav(file_path):
    """Load a WAV file and return sample rate and data."""
    sample_rate, data = wavfile.read(file_path)
    return sample_rate, data


def load_all_wavs(data_dir):
    """Load all WAV files in the given directory."""
    wav_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
    all_data = []
    for file_path in wav_files:
        _, data = load_wav(file_path)
        all_data.append(data)
    return all_data


def save_wav(file_path, sample_rate, data):
    """Save data to a WAV file (written as float32)."""
    wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))


def delta_encode(data):
    """Delta encoding: first value verbatim, then successive differences."""
    deltas = [data[0]]
    for i in range(1, len(data)):
        deltas.append(data[i] - data[i - 1])
    return deltas


def delta_decode(deltas):
    """Decode delta-encoded data by cumulative summation."""
    data = [deltas[0]]
    for i in range(1, len(deltas)):
        data.append(data[-1] + deltas[i])
    return data
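Delta encoding stores the first sample verbatim and differences afterwards, so decoding is an exact inverse for integer input. A quick round-trip check (toy values):
```python
from data_processing import delta_encode, delta_decode

signal = [3, 5, 4, 4, 7]
deltas = delta_encode(signal)          # [3, 2, -1, 0, 3]
assert delta_decode(deltas) == signal  # exact for integer input
```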

main.py Normal file

@@ -0,0 +1,69 @@
from slate import Slate, Slate_Runner

from bitstream import ArithmeticEncoder
from data_processing import download_and_extract_data, load_all_wavs, delta_encode
from model import LSTMPredictor, FixedInputNNPredictor
from train import train_model


class SpikeRunner(Slate_Runner):
    def setup(self, name):
        self.name = name
        slate, config = self.slate, self.config

        # Consume config sections
        preprocessing_config = slate.consume(config, 'preprocessing', expand=True)
        predictor_config = slate.consume(config, 'predictor', expand=True)
        training_config = slate.consume(config, 'training', expand=True)
        bitstream_config = slate.consume(config, 'bitstream_encoding', expand=True)
        data_config = slate.consume(config, 'data', expand=True)

        # Data setup
        data_url = slate.consume(data_config, 'url')
        data_dir = slate.consume(data_config, 'directory')
        download_and_extract_data(data_url, data_dir)
        all_data = load_all_wavs(data_dir)

        # Consume once and reuse (a second consume of the same key would fail).
        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
        if self.use_delta_encoding:
            all_data = [delta_encode(d) for d in all_data]

        # Split data into train and test sets
        split_ratio = slate.consume(data_config, 'split_ratio', 0.8)
        split_idx = int(len(all_data) * split_ratio)
        self.train_data = all_data[:split_idx]
        self.test_data = all_data[split_idx:]

        # Model setup
        self.model = self.get_model(predictor_config)
        self.encoder = self.get_encoder(bitstream_config)
        self.epochs = slate.consume(training_config, 'epochs')
        self.batch_size = slate.consume(training_config, 'batch_size')
        self.learning_rate = slate.consume(training_config, 'learning_rate')
        self.eval_freq = slate.consume(training_config, 'eval_freq')
        self.save_path = slate.consume(training_config, 'save_path', 'models')

    def get_model(self, config):
        model_type = self.slate.consume(config, 'type')
        if model_type == 'lstm':
            return LSTMPredictor(
                input_size=self.slate.consume(config, 'input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size'),
                num_layers=self.slate.consume(config, 'num_layers')
            )
        elif model_type == 'fixed_input_nn':
            return FixedInputNNPredictor(
                input_size=self.slate.consume(config, 'fixed_input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size')
            )
        raise ValueError(f"Unknown model type: {model_type}")

    def get_encoder(self, config):
        return ArithmeticEncoder()

    def run(self, run, forceNoProfile=False):
        train_model(
            self.model, self.train_data, self.test_data, self.epochs,
            self.batch_size, self.learning_rate, self.use_delta_encoding,
            self.encoder, self.eval_freq, self.save_path
        )


if __name__ == '__main__':
    slate = Slate({'spikey': SpikeRunner})
    slate.from_args()

model.py Normal file

@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseModel(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data):
        pass


class LSTMPredictor(BaseModel):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, (h0, c0))
        out = self.fc(out)
        return out

    def encode(self, data):
        """Predictive coding: first sample verbatim, then prediction residuals."""
        self.eval()
        encoded_data = [data[0]]  # no context yet for the first sample
        with torch.no_grad():
            for i in range(1, len(data)):
                context = torch.tensor(data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context)[0, -1].item()  # last-step prediction for sample i
                encoded_data.append(data[i] - prediction)
        return encoded_data

    def decode(self, encoded_data):
        """Inverse of encode: rebuild each sample as prediction plus residual."""
        self.eval()
        decoded_data = [encoded_data[0]]
        with torch.no_grad():
            for i in range(1, len(encoded_data)):
                context = torch.tensor(decoded_data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context)[0, -1].item()
                decoded_data.append(prediction + encoded_data[i])
        return decoded_data


class FixedInputNNPredictor(BaseModel):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.input_size = input_size

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def encode(self, data):
        """Store the first input_size samples verbatim, then prediction residuals."""
        self.eval()
        encoded_data = list(data[:self.input_size])
        with torch.no_grad():
            for i in range(self.input_size, len(data)):
                context = torch.tensor(data[i - self.input_size:i]).view(1, -1).float()
                prediction = self.forward(context).item()
                encoded_data.append(data[i] - prediction)
        return encoded_data

    def decode(self, encoded_data):
        self.eval()
        decoded_data = list(encoded_data[:self.input_size])
        with torch.no_grad():
            for i in range(self.input_size, len(encoded_data)):
                context = torch.tensor(decoded_data[i - self.input_size:i]).view(1, -1).float()
                prediction = self.forward(context).item()
                decoded_data.append(prediction + encoded_data[i])
        return decoded_data
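Both predictors implement the same predictive-coding contract: `encode` emits residuals against the model's next-sample prediction, and `decode` replays those predictions to reconstruct the signal. A round-trip sketch with untrained weights (exact only up to float error; a real pipeline would quantize residuals before entropy coding):
```python
import torch
from model import LSTMPredictor

torch.manual_seed(0)
model = LSTMPredictor(input_size=1, hidden_size=8, num_layers=1)
signal = [0.1, 0.3, 0.2, 0.4, 0.35]

residuals = model.encode(signal)   # first sample verbatim, then residuals
restored = model.decode(residuals)
assert all(abs(a - b) < 1e-4 for a, b in zip(restored, signal))
```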

requirements.txt Normal file

@@ -0,0 +1,6 @@
torch
numpy
scipy
matplotlib
wandb
pyyaml
arithmetic-compressor

train.py Normal file

@@ -0,0 +1,113 @@
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb

from data_processing import delta_decode
from utils import visualize_prediction, plot_delta_distribution


def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0, num_points=None):
    compression_ratios = []
    identical_count = 0
    all_deltas = []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for file_data in data:
        file_array = np.asarray(file_data, dtype=np.float32)
        file_tensor = torch.tensor(file_array).view(1, -1, 1).to(device)  # (batch, seq, features)
        encoded_data = model(file_tensor).view(-1).detach().cpu().numpy().tolist()
        encoder.build_model(encoded_data)
        compressed_data = encoder.encode(encoded_data)
        decompressed_data = encoder.decode(compressed_data, len(encoded_data))

        # Check equivalence of the coder round-trip.
        if use_delta_encoding:
            decompressed_data = delta_decode(decompressed_data)
        if np.allclose(file_array, decompressed_data, atol=1e-5):
            identical_count += 1

        # NOTE: samples per compressed bit; a bit-accurate ratio would compare
        # 10-bit samples against the compressed bitstream length.
        compression_ratios.append(len(file_array) / len(compressed_data))

        # Collect deltas between the input and the model's predictions.
        predicted_data = encoded_data
        if use_delta_encoding:
            predicted_data = delta_decode(predicted_data)
        delta_data = [float(file_array[i]) - predicted_data[i] for i in range(len(file_array))]
        all_deltas.extend(delta_data)

        # Visualize prediction vs data vs error
        visualize_prediction(file_array, predicted_data, delta_data, sample_rate, num_points)

    identical_percentage = (identical_count / len(data)) * 100

    # Plot delta distribution
    delta_plot_path = plot_delta_distribution(all_deltas, epoch)
    wandb.log({"delta_distribution": wandb.Image(delta_plot_path)})
    return compression_ratios, identical_percentage


def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path, num_points=None):
    """Train the model and periodically evaluate compression performance."""
    wandb.init(project="wav-compression")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_test_score = float('-inf')  # higher mean compression ratio is better
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(epochs):
        total_loss = 0
        random.shuffle(train_data)  # shuffle files for varied batches
        for i in range(0, len(train_data) - batch_size, batch_size):
            # Next-sample prediction: inputs are samples 0..N-2, targets 1..N-1.
            # NOTE: stacking assumes all files in a batch share the same length.
            batch = np.stack([np.asarray(d, dtype=np.float32) for d in train_data[i:i + batch_size]])
            inputs = torch.tensor(batch[:, :-1]).unsqueeze(2).to(device)
            targets = torch.tensor(batch[:, 1:]).unsqueeze(2).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        wandb.log({"epoch": epoch, "loss": total_loss})
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')

        if (epoch + 1) % eval_freq == 0:
            # Evaluate on train and test data
            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)

            # Log statistics
            wandb.log({
                "train_compression_ratio_mean": np.mean(train_compression_ratios),
                "train_compression_ratio_std": np.std(train_compression_ratios),
                "train_compression_ratio_min": np.min(train_compression_ratios),
                "train_compression_ratio_max": np.max(train_compression_ratios),
                "test_compression_ratio_mean": np.mean(test_compression_ratios),
                "test_compression_ratio_std": np.std(test_compression_ratios),
                "test_compression_ratio_min": np.min(test_compression_ratios),
                "test_compression_ratio_max": np.max(test_compression_ratios),
                "train_identical_percentage": train_identical_percentage,
                "test_identical_percentage": test_identical_percentage,
            })
            print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
            print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')

            # Save model and encoder on a new best (highest) mean test compression ratio
            test_score = np.mean(test_compression_ratios)
            if test_score > best_test_score:
                best_test_score = test_score
                model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
                encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
                torch.save(model.state_dict(), model_path)
                with open(encoder_path, 'wb') as f:
                    pickle.dump(encoder, f)
                print(f'New best on test data! Model and encoder saved to {model_path} and {encoder_path}')
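Note that `evaluate_model` feeds raw float model outputs to the arithmetic coder, which treats every distinct float as its own symbol; that runs mechanically but yields a degenerate source model. A minimal quantization step (assumed here, not part of this commit) would map residuals onto a small integer alphabet first:
```python
import numpy as np

def quantize(residuals, step=1.0):
    """Map float residuals to integer symbols; lossy unless values are multiples of step."""
    return np.round(np.asarray(residuals) / step).astype(int).tolist()

def dequantize(symbols, step=1.0):
    """Inverse mapping back to the float domain."""
    return [s * step for s in symbols]
```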

utils.py Normal file

@@ -0,0 +1,63 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import wandb


def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
    """Visualize WAV data using matplotlib."""
    if num_points:
        data = data[:num_points]
    plt.figure(figsize=(10, 4))
    plt.plot(np.linspace(0, len(data) / sample_rate, num=len(data)), data)
    plt.title(title)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')
    plt.show()


def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None):
    """Plot true data, predicted data, and deltas; log the figure to wandb."""
    if num_points:
        true_data = true_data[:num_points]
        predicted_data = predicted_data[:num_points]
        delta_data = delta_data[:num_points]

    plt.figure(figsize=(15, 5))
    plt.subplot(3, 1, 1)
    plt.plot(true_data, label='True Data')
    plt.title('True Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 2)
    plt.plot(predicted_data, label='Predicted Data', color='orange')
    plt.title('Predicted Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 3)
    plt.plot(delta_data, label='Delta', color='red')
    plt.title('Delta')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.tight_layout()
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(int(1e6))}.png')
    plt.savefig(file_path)
    plt.close()
    wandb.log({"Prediction vs True Data": wandb.Image(file_path)})


def plot_delta_distribution(deltas, epoch):
    """Plot the distribution of deltas and return the saved figure path."""
    plt.figure(figsize=(10, 6))
    plt.hist(deltas, bins=100, density=True, alpha=0.6, color='g')
    plt.title(f'Delta Distribution at Epoch {epoch}')
    plt.xlabel('Delta')
    plt.ylabel('Density')
    plt.grid(True)
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(int(1e6))}.png')
    plt.savefig(file_path)
    plt.close()
    return file_path