commit 0c5f888d755dfed6fd828b80698a424f9d33a7bc Author: Dominik Roth Date: Fri May 24 22:01:59 2024 +0200 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0a195e6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +data +data.zip +__pycache__ +.venv +wandb +slurm_log +job_hist.log +models +Xvfb.log +profiler diff --git a/README.md b/README.md new file mode 100644 index 0000000..ce1ed85 --- /dev/null +++ b/README.md @@ -0,0 +1,76 @@ +# Spikey + +This repository contains a solution for the [Neuralink Compression Challenge](https://content.neuralink.com/compression-challenge/README.html). The challenge involves compressing raw electrode recordings from a Neuralink implant. These recordings are taken from the motor cortex of a non-human primate while playing a video game. + +## Challenge Overview + +The Neuralink N1 implant generates approximately 200Mbps of electrode data (1024 electrodes @ 20kHz, 10-bit resolution) and can transmit data wirelessly at about 1Mbps. This means a compression ratio of over 200x is required. The compression must run in real-time (< 1ms) and consume low power (< 10mW, including radio). + +## Installation + +To install the necessary dependencies, create a virtual environment and install the requirements: + +```bash +python3 -m venv env +source env/bin/activate +pip install -r requirements.txt +``` + +## Usage + +### Configuration + +The configuration for training and evaluation is specified in a YAML file. Below is an example configuration: + +```yaml +name: Test + +preprocessing: + use_delta_encoding: true # Whether to use delta encoding. + +predictor: + type: lstm # Options: 'lstm', 'fixed_input_nn' + input_size: 1 # Input size for the LSTM predictor. + hidden_size: 128 # Hidden size for the LSTM or Fixed Input NN predictor. + num_layers: 2 # Number of layers for the LSTM predictor. + fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'. 
class BaseEncoder(ABC):
    """Interface for bitstream encoders that code a sequence of symbols."""

    @abstractmethod
    def encode(self, data):
        """Compress the symbol sequence ``data`` and return the coded result."""

    @abstractmethod
    def decode(self, encoded_data, num_symbols):
        """Recover ``num_symbols`` symbols from ``encoded_data``."""

    @abstractmethod
    def build_model(self, data):
        """Fit the encoder's symbol model to ``data``."""


class ArithmeticEncoder(BaseEncoder):
    """Arithmetic coder driven by a static symbol-frequency model.

    ``build_model`` must be called before ``encode`` or ``decode``; it stores
    the model on ``self.model``, which both directions then share.
    """

    def _require_model(self, action):
        """Raise ``ValueError`` if ``build_model`` has not been called yet."""
        if not hasattr(self, 'model'):
            raise ValueError(f"Model not built. Call build_model(data) before {action}.")

    def encode(self, data):
        """Compress ``data`` with the previously built model.

        Raises:
            ValueError: if ``build_model`` has not been called.
        """
        self._require_model("encoding")
        coder = AECompressor(self.model)
        return coder.compress(data)

    def decode(self, encoded_data, num_symbols):
        """Decompress ``num_symbols`` symbols from ``encoded_data``.

        Raises:
            ValueError: if ``build_model`` has not been called (previously
                this surfaced as an opaque ``AttributeError``).
        """
        self._require_model("decoding")
        coder = AECompressor(self.model)
        return coder.decompress(encoded_data, num_symbols)

    def build_model(self, data):
        """Build a static probability model from the symbol frequencies in ``data``.

        Each distinct symbol gets probability ``count / len(data)``. Symbols
        must be hashable.
        """
        # Local import keeps this block self-contained; Counter tallies all
        # symbols in a single pass instead of one data.count() scan per
        # distinct symbol, and accepts any iterable rather than only lists.
        from collections import Counter

        symbol_counts = Counter(data)
        total_symbols = sum(symbol_counts.values())
        probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
        # NOTE(review): behaviour of StaticModel for an empty probability
        # table (empty ``data``) is not visible here - confirm before relying on it.
        self.model = StaticModel(probabilities)
[spike_runner.encoder.encode(ed) for ed in encoded_data] + + with open(args.output_file, 'wb') as f: + for cd in compressed_data: + f.write(bytearray(cd)) + + elif args.action == 'decompress': + with open(args.input_file, 'rb') as f: + compressed_data = list(f.read()) + + decoded_data = [spike_runner.encoder.decode(cd, len(cd)) for cd in compressed_data] + if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'): + decoded_data = [delta_decode(dd) for dd in decoded_data] + + save_wav(args.output_file, 19531, decoded_data) # Assuming 19531 Hz sample rate + +if __name__ == "__main__": + main() diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..91a9015 --- /dev/null +++ b/config.yaml @@ -0,0 +1,67 @@ +name: DEFAULT +project: Spikey + +slurm: + name: 'Spikey_{config[name]}' + partitions: + - single + standard_output: ./reports/slurm/out_%A_%a.log + standard_error: ./reports/slurm/err_%A_%a.log + num_parallel_jobs: 50 + cpus_per_task: 4 + memory_per_cpu: 1000 + time_limit: 1440 # in minutes + ntasks: 1 + venv: '.venv/bin/activate' + sh_lines: + - 'mkdir -p {tmp}/wandb' + - 'mkdir -p {tmp}/local_pycache' + - 'export PYTHONPYCACHEPREFIX={tmp}/local_pycache' + +runner: spikey + +scheduler: + reps_per_version: 1 + agents_per_job: 1 + reps_per_agent: 1 + +wandb: + project: '{config[project]}' + group: '{config[name]}' + job_type: '{delta_desc}' + name: '{job_id}_{task_id}:{run_id}:{rand}={config[name]}_{delta_desc}' + tags: + - '{config[env][name]}' + - '{config[algo][name]}' + sync_tensorboard: False + monitor_gym: False + save_code: False + +--- +name: Test + +preprocessing: + use_delta_encoding: true # Whether to use delta encoding. + +predictor: + type: lstm # Options: 'lstm', 'fixed_input_nn' + input_size: 1 # Input size for the LSTM predictor. + hidden_size: 128 # Hidden size for the LSTM or Fixed Input NN predictor. + num_layers: 2 # Number of layers for the LSTM predictor. 
def download_and_extract_data(url, data_dir):
    """Download the zip archive at ``url`` and extract it into ``data_dir``.

    Does nothing when ``data_dir`` already exists, so an earlier successful
    download acts as a cache. If the download or extraction fails, the
    freshly created directory is removed again; otherwise a later call would
    mistake the empty or partial directory for a cached dataset.
    """
    if os.path.exists(data_dir):
        return

    os.makedirs(data_dir)
    zip_path = os.path.join(data_dir, 'data.zip')
    try:
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    except BaseException:
        # Also covers Ctrl-C during a long download; always re-raised.
        import shutil

        shutil.rmtree(data_dir, ignore_errors=True)
        raise
    os.remove(zip_path)


def load_wav(file_path):
    """Load a WAV file and return ``(sample_rate, data)``."""
    sample_rate, data = wavfile.read(file_path)
    return sample_rate, data


def load_all_wavs(data_dir):
    """Load the sample data of every ``.wav`` file in ``data_dir``.

    ``data_dir`` may also be the path of a single WAV file, in which case a
    one-element list is returned. Files are loaded in sorted name order so
    that callers splitting the list get the same split on every run.
    """
    if os.path.isfile(data_dir):
        wav_files = [data_dir]
    else:
        wav_files = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.wav')
        )
    return [load_wav(file_path)[1] for file_path in wav_files]


def save_wav(file_path, sample_rate, data):
    """Save ``data`` to a WAV file as 32-bit float samples."""
    wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))


def delta_encode(data):
    """Return ``data`` as a list of first differences.

    The first element is kept as-is; every later element is replaced by its
    difference to the predecessor. Empty input yields an empty list.
    """
    if len(data) == 0:
        return []
    deltas = [data[0]]
    deltas.extend(current - previous for previous, current in zip(data, data[1:]))
    return deltas


def delta_decode(deltas):
    """Invert ``delta_encode``: return the running sum of ``deltas`` as a list.

    Empty input yields an empty list.
    """
    from itertools import accumulate

    return list(accumulate(deltas))
class BaseModel(ABC, nn.Module):
    """Interface for next-sample predictors used as the modelling stage.

    Subclasses provide ``forward`` plus an ``encode``/``decode`` pair that
    turns a signal into prediction residuals and back.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Run the network on the input tensor ``x``."""

    @abstractmethod
    def encode(self, data):
        """Return the prediction residuals for the sample sequence ``data``."""

    @abstractmethod
    def decode(self, encoded_data):
        """Rebuild the sample sequence from residuals produced by ``encode``."""


class LSTMPredictor(BaseModel):
    """LSTM that predicts the next sample from the preceding samples."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        """Map a ``(batch, seq, input_size)`` tensor to per-step predictions.

        Returns a ``(batch, seq, 1)`` tensor: the linear head is applied to
        the LSTM output of every time step.
        """
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, (h0, c0))
        out = self.fc(out)
        return out

    def _predict_next(self, history):
        """Predict the sample that follows ``history`` (a 1-D sequence).

        With no history there is nothing to feed the LSTM, so the prediction
        is defined as 0.0; the first residual then carries the raw sample.
        """
        if len(history) == 0:
            return 0.0
        device = next(self.parameters()).device
        context = torch.as_tensor(history, dtype=torch.float32).view(1, -1, 1).to(device)
        # forward() yields one output per time step; only the last one is the
        # prediction for the sample following the context.
        return self.forward(context)[0, -1, 0].item()

    def encode(self, data):
        """Return one residual ``data[i] - prediction`` per input sample.

        The context for sample ``i`` is the up-to-``hidden_size`` samples
        before it, exactly what ``decode`` has available at the same step.
        """
        self.eval()
        encoded_data = []

        with torch.no_grad():
            for i in range(len(data)):
                history = data[max(0, i - self.hidden_size):i]
                encoded_data.append(data[i] - self._predict_next(history))

        return encoded_data

    def decode(self, encoded_data):
        """Rebuild the samples from the residuals produced by ``encode``.

        NOTE(review): residuals are floats, so reconstruction is exact only
        up to floating-point rounding; a strictly lossless codec would need
        integer (rounded) predictions - confirm the intended guarantee.
        """
        self.eval()
        decoded_data = []

        with torch.no_grad():
            for i, residual in enumerate(encoded_data):
                history = decoded_data[max(0, i - self.hidden_size):i]
                decoded_data.append(self._predict_next(history) + residual)

        return decoded_data


class FixedInputNNPredictor(BaseModel):
    """Two-layer MLP that predicts the next sample from a fixed-size window."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.input_size = input_size

    def forward(self, x):
        """Map a ``(..., input_size)`` tensor to a ``(..., 1)`` prediction."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def _predict_next(self, history):
        """Predict the sample following ``history`` (at most ``input_size`` long).

        Shorter histories are left-padded with zeros so the first layer
        always receives exactly ``input_size`` features.
        """
        padding = [0.0] * (self.input_size - len(history))
        window = padding + [float(value) for value in history]
        device = next(self.parameters()).device
        context = torch.tensor(window, dtype=torch.float32, device=device).view(1, -1)
        return self.forward(context).item()

    def encode(self, data):
        """Return one residual ``data[i] - prediction`` per input sample.

        Every sample is encoded (including the first ``input_size`` ones) so
        that ``decode`` can rebuild the complete sequence.
        """
        self.eval()
        encoded_data = []

        with torch.no_grad():
            for i in range(len(data)):
                history = data[max(0, i - self.input_size):i]
                encoded_data.append(data[i] - self._predict_next(history))

        return encoded_data

    def decode(self, encoded_data):
        """Rebuild the samples from the residuals produced by ``encode``.

        NOTE(review): reconstruction is exact only up to floating-point
        rounding, because residuals and predictions are floats.
        """
        self.eval()
        decoded_data = []

        with torch.no_grad():
            for i, residual in enumerate(encoded_data):
                history = decoded_data[max(0, i - self.input_size):i]
                decoded_data.append(self._predict_next(history) + residual)

        return decoded_data
b/requirements.txt new file mode 100644 index 0000000..ba358e8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +torch +numpy +scipy +matplotlib +wandb +pyyaml \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..3cf4cf9 --- /dev/null +++ b/train.py @@ -0,0 +1,113 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import wandb +import random +import os +import pickle +from data_processing import delta_encode, delta_decode, save_wav +from utils import visualize_prediction, plot_delta_distribution +from bitstream import ArithmeticEncoder + +def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0, num_points=None): + compression_ratios = [] + identical_count = 0 + all_deltas = [] + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model.to(device) + + for file_data in data: + file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(device) + encoded_data = model(file_data).squeeze(1).cpu().detach().numpy().tolist() + encoder.build_model(encoded_data) + compressed_data = encoder.encode(encoded_data) + decompressed_data = encoder.decode(compressed_data, len(encoded_data)) + + # Check equivalence + if use_delta_encoding: + decompressed_data = delta_decode(decompressed_data) + identical = np.allclose(file_data.cpu().numpy(), decompressed_data, atol=1e-5) + if identical: + identical_count += 1 + + compression_ratio = len(file_data) / len(compressed_data) + compression_ratios.append(compression_ratio) + + # Compute and collect deltas + predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(device)).squeeze(1).cpu().detach().numpy().tolist() + if use_delta_encoding: + predicted_data = delta_decode(predicted_data) + delta_data = [file_data[i].item() - predicted_data[i] for i in range(len(file_data))] + all_deltas.extend(delta_data) + + # Visualize prediction vs data vs error + 
def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path, num_points=None):
    """Train ``model`` as a next-sample predictor and checkpoint the best epoch.

    Every ``eval_freq`` epochs the model is evaluated on the train and test
    sets with ``evaluate_model``; statistics are logged to wandb and printed.
    Whenever the mean test compression ratio exceeds the best value seen so
    far, the model weights and the pickled encoder are written to
    ``save_path``.

    Args:
        model: predictor module; called on ``(batch, seq, 1)`` float tensors.
        train_data: list of per-file sample sequences.
            NOTE(review): a batch is stacked into one tensor, so the files
            in a batch presumably need equal length - confirm.
        test_data: list of per-file sample sequences used for evaluation.
        epochs: number of passes over ``train_data``.
        batch_size: number of files per optimisation step.
        learning_rate: Adam learning rate.
        use_delta_encoding: forwarded to ``evaluate_model``.
        encoder: bitstream encoder forwarded to ``evaluate_model`` and pickled
            with each checkpoint.
        eval_freq: evaluate every this many epochs.
        save_path: directory for checkpoints; created if missing.
        num_points: forwarded to ``evaluate_model``.
    """
    wandb.init(project="wav-compression")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Higher compression ratio is better, so start below any real score.
    best_test_score = float('-inf')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Shuffle a private copy so the caller's list keeps its order.
    shuffled_data = list(train_data)

    for epoch in range(epochs):
        total_loss = 0
        random.shuffle(shuffled_data)  # Shuffle data for varied batches
        for start in range(0, len(shuffled_data), batch_size):
            batch = torch.tensor(shuffled_data[start:start + batch_size], dtype=torch.float32).unsqueeze(2).to(device)
            # Next-sample prediction: the target for time step t is the
            # sample at t + 1 of the *same* file (previously the targets were
            # taken from the following files in the list).
            inputs = batch[:, :-1]
            targets = batch[:, 1:]
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        wandb.log({"epoch": epoch, "loss": total_loss})
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')

        if (epoch + 1) % eval_freq == 0:
            # Evaluate on train and test data
            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)

            # Log statistics
            wandb.log({
                "train_compression_ratio_mean": np.mean(train_compression_ratios),
                "train_compression_ratio_std": np.std(train_compression_ratios),
                "train_compression_ratio_min": np.min(train_compression_ratios),
                "train_compression_ratio_max": np.max(train_compression_ratios),
                "test_compression_ratio_mean": np.mean(test_compression_ratios),
                "test_compression_ratio_std": np.std(test_compression_ratios),
                "test_compression_ratio_min": np.min(test_compression_ratios),
                "test_compression_ratio_max": np.max(test_compression_ratios),
                "train_identical_percentage": train_identical_percentage,
                "test_identical_percentage": test_identical_percentage,
            })

            print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
            print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')

            # Save model and encoder if new highscore on test data
            test_score = np.mean(test_compression_ratios)
            if test_score > best_test_score:
                best_test_score = test_score
                # torch.save/open do not create missing directories.
                os.makedirs(save_path, exist_ok=True)
                model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
                encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
                torch.save(model.state_dict(), model_path)
                with open(encoder_path, 'wb') as f:
                    pickle.dump(encoder, f)
                print(f'New highscore on test data! Model and encoder saved to {model_path} and {encoder_path}')
file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1e6)}.png') + plt.savefig(file_path) + plt.close() + return file_path