initial commit
commit 0c5f888d75

.gitignore (new file, vendored, 10 lines)
@@ -0,0 +1,10 @@
data
data.zip
__pycache__
.venv
wandb
slurm_log
job_hist.log
models
Xvfb.log
profiler

README.md (new file, 76 lines)
@@ -0,0 +1,76 @@
# Spikey

This repository contains a solution for the [Neuralink Compression Challenge](https://content.neuralink.com/compression-challenge/README.html). The challenge involves compressing raw electrode recordings from a Neuralink implant. These recordings are taken from the motor cortex of a non-human primate while it plays a video game.

## Challenge Overview

The Neuralink N1 implant generates approximately 200 Mbps of electrode data (1024 electrodes @ 20 kHz, 10-bit resolution) but can transmit data wirelessly at only about 1 Mbps. This means a compression ratio of over 200x is required. The compression must run in real time (< 1 ms) and consume low power (< 10 mW, including the radio).
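
As a quick sanity check on these figures (a back-of-the-envelope sketch, not part of the original repository):

```python
# Raw data rate of the N1 implant and the implied compression requirement.
electrodes = 1024
sample_rate_hz = 20_000
bits_per_sample = 10

raw_mbps = electrodes * sample_rate_hz * bits_per_sample / 1e6  # 204.8 Mbps
required_ratio = raw_mbps / 1.0                                 # ~205x for a 1 Mbps radio
print(f"raw: {raw_mbps} Mbps -> required compression: {required_ratio:.0f}x")
```
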
## Installation

To install the necessary dependencies, create a virtual environment and install the requirements:

```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```

## Usage

### Configuration

The configuration for training and evaluation is specified in a YAML file. Below is an example configuration:

```yaml
name: Test

preprocessing:
  use_delta_encoding: true # Whether to use delta encoding.

predictor:
  type: lstm # Options: 'lstm', 'fixed_input_nn'
  input_size: 1 # Input size for the LSTM predictor.
  hidden_size: 128 # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2 # Number of layers for the LSTM predictor.
  fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10 # Number of training epochs.
  batch_size: 32 # Batch size for training.
  learning_rate: 0.001 # Learning rate for the optimizer.
  eval_freq: 2 # Frequency of evaluation during training (in epochs).
  save_path: models # Directory to save the best model and encoder.
  num_points: 1000 # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
  directory: data # Directory to extract and store the dataset.
  split_ratio: 0.8 # Ratio to split the data into train and test sets.
```

### Running the Code

To train the model and compress/decompress WAV files, use the provided CLI:

```bash
python cli.py compress --config config.yaml --input_file path/to/input.wav --output_file path/to/output.bin
python cli.py decompress --config config.yaml --input_file path/to/output.bin --output_file path/to/output.wav
```

### Training

Training requires Slate, which is not currently publicly available. Install it via (requires repo access):

```bash
pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
```

To train the model, run:

```bash
python main.py config.yaml Test
```

bitstream.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from arithmetic_compressor import AECompressor
from arithmetic_compressor.models import StaticModel

class BaseEncoder(ABC):
    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data, num_symbols):
        pass

    @abstractmethod
    def build_model(self, data):
        pass

class ArithmeticEncoder(BaseEncoder):
    def encode(self, data):
        if not hasattr(self, 'model'):
            raise ValueError("Model not built. Call build_model(data) before encoding.")
        coder = AECompressor(self.model)
        compressed_data = coder.compress(data)
        return compressed_data

    def decode(self, encoded_data, num_symbols):
        coder = AECompressor(self.model)
        decoded_data = coder.decompress(encoded_data, num_symbols)
        return decoded_data

    def build_model(self, data):
        symbol_counts = {symbol: data.count(symbol) for symbol in set(data)}
        total_symbols = sum(symbol_counts.values())
        probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
        self.model = StaticModel(probabilities)
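
A minimal round-trip sketch of how `ArithmeticEncoder` is meant to be used (not part of this commit; it assumes the `arithmetic_compressor` package round-trips a symbol list as the code above relies on):

```python
from bitstream import ArithmeticEncoder

symbols = [0, 1, 1, 0, 2, 1, 0, 0]       # any hashable symbols
enc = ArithmeticEncoder()
enc.build_model(symbols)                  # static model from symbol frequencies
compressed = enc.encode(symbols)
restored = enc.decode(compressed, num_symbols=len(symbols))
assert restored == symbols
```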
							
								
								
									
cli.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import argparse
import yaml
import os
import torch
from data_processing import download_and_extract_data, load_all_wavs, save_wav, delta_encode, delta_decode
from main import SpikeRunner

def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

def main():
    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
    parser.add_argument('--input_file', help="Path to the input WAV file")
    parser.add_argument('--output_file', help="Path to the output file")
    args = parser.parse_args()

    config = load_config(args.config)

    spike_runner = SpikeRunner(None, config)
    spike_runner.setup('CLI')

    if args.action == 'compress':
        data = load_all_wavs(args.input_file)
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            data = [delta_encode(d) for d in data]

        spike_runner.encoder.build_model(data)
        encoded_data = [spike_runner.model(torch.tensor(d, dtype=torch.float32).unsqueeze(0)).squeeze(0).detach().numpy().tolist() for d in data]
        compressed_data = [spike_runner.encoder.encode(ed) for ed in encoded_data]

        with open(args.output_file, 'wb') as f:
            for cd in compressed_data:
                f.write(bytearray(cd))

    elif args.action == 'decompress':
        with open(args.input_file, 'rb') as f:
            compressed_data = list(f.read())

        decoded_data = [spike_runner.encoder.decode(cd, len(cd)) for cd in compressed_data]
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            decoded_data = [delta_decode(dd) for dd in decoded_data]

        save_wav(args.output_file, 19531, decoded_data)  # Assuming 19531 Hz sample rate

if __name__ == "__main__":
    main()
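
Note that `f.write(bytearray(cd))` assumes the arithmetic coder returns byte-sized integers (0-255); if it emits a list of bits instead, the stream would need packing before writing. A sketch of such helpers (hypothetical names, not part of this commit):

```python
def pack_bits(bits):
    """Pack a list of 0/1 ints into bytes, zero-padding the tail."""
    padded = bits + [0] * (-len(bits) % 8)
    out = bytearray()
    for i in range(0, len(padded), 8):
        byte = 0
        for bit in padded[i:i + 8]:
            byte = (byte << 1) | bit
        out.append(byte)
    return bytes(out)

def unpack_bits(raw, num_bits):
    """Inverse of pack_bits: recover the first num_bits bits."""
    bits = []
    for byte in raw:
        bits.extend((byte >> k) & 1 for k in range(7, -1, -1))
    return bits[:num_bits]
```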
							
								
								
									
config.yaml (new file, 67 lines)
@@ -0,0 +1,67 @@
name: DEFAULT
project: Spikey

slurm:
  name: 'Spikey_{config[name]}'
  partitions:
    - single
  standard_output: ./reports/slurm/out_%A_%a.log
  standard_error: ./reports/slurm/err_%A_%a.log
  num_parallel_jobs: 50
  cpus_per_task: 4
  memory_per_cpu: 1000
  time_limit: 1440  # in minutes
  ntasks: 1
  venv: '.venv/bin/activate'
  sh_lines:
    - 'mkdir -p {tmp}/wandb'
    - 'mkdir -p {tmp}/local_pycache'
    - 'export PYTHONPYCACHEPREFIX={tmp}/local_pycache'

runner: spikey

scheduler:
  reps_per_version: 1
  agents_per_job: 1
  reps_per_agent: 1

wandb:
  project: '{config[project]}'
  group: '{config[name]}'
  job_type: '{delta_desc}'
  name: '{job_id}_{task_id}:{run_id}:{rand}={config[name]}_{delta_desc}'
  tags:
    - '{config[env][name]}'
    - '{config[algo][name]}'
  sync_tensorboard: False
  monitor_gym: False
  save_code: False

---
name: Test

preprocessing:
  use_delta_encoding: true  # Whether to use delta encoding.

predictor:
  type: lstm  # Options: 'lstm', 'fixed_input_nn'
  input_size: 1  # Input size for the LSTM predictor.
  hidden_size: 128  # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2  # Number of layers for the LSTM predictor.
  fixed_input_size: 10  # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10  # Number of training epochs.
  batch_size: 32  # Batch size for training.
  learning_rate: 0.001  # Learning rate for the optimizer.
  eval_freq: 2  # Frequency of evaluation during training (in epochs).
  save_path: models  # Directory to save the best model and encoder.
  num_points: 1000  # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic  # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
  directory: data  # Directory to extract and store the dataset.
  split_ratio: 0.8  # Ratio to split the data into train and test sets.

data_processing.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import os
import numpy as np
from scipy.io import wavfile
import urllib.request
import zipfile

def download_and_extract_data(url, data_dir):
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
        zip_path = os.path.join(data_dir, 'data.zip')
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        os.remove(zip_path)

def load_wav(file_path):
    """Load WAV file and return sample rate and data."""
    sample_rate, data = wavfile.read(file_path)
    return sample_rate, data

def load_all_wavs(data_dir):
    """Load all WAV files in the given directory."""
    wav_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
    all_data = []
    for file_path in wav_files:
        _, data = load_wav(file_path)
        all_data.append(data)
    return all_data

def save_wav(file_path, sample_rate, data):
    """Save data to a WAV file."""
    wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))

def delta_encode(data):
    """Apply delta encoding to the data."""
    deltas = [data[0]]
    for i in range(1, len(data)):
        deltas.append(data[i] - data[i - 1])
    return deltas

def delta_decode(deltas):
    """Decode delta encoded data."""
    data = [deltas[0]]
    for i in range(1, len(deltas)):
        data.append(data[-1] + deltas[i])
    return data
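
Delta coding keeps the first sample verbatim and then stores only successive differences, which concentrates values near zero and gives the entropy coder a more peaked distribution to exploit. A quick round trip with the functions above:

```python
from data_processing import delta_encode, delta_decode

signal = [10, 12, 11, 11, 15]
deltas = delta_encode(signal)          # [10, 2, -1, 0, 4]
assert delta_decode(deltas) == signal  # lossless round trip
```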
							
								
								
									
main.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import yaml
from slate import Slate, Slate_Runner
from data_processing import download_and_extract_data, load_all_wavs, delta_encode
from model import LSTMPredictor, FixedInputNNPredictor
from train import train_model
from bitstream import ArithmeticEncoder

class SpikeRunner(Slate_Runner):
    def setup(self, name):
        self.name = name
        slate, config = self.slate, self.config

        # Consume config sections
        preprocessing_config = slate.consume(config, 'preprocessing', expand=True)
        predictor_config = slate.consume(config, 'predictor', expand=True)
        training_config = slate.consume(config, 'training', expand=True)
        bitstream_config = slate.consume(config, 'bitstream_encoding', expand=True)
        data_config = slate.consume(config, 'data', expand=True)

        # Data setup
        data_url = slate.consume(data_config, 'url')
        data_dir = slate.consume(data_config, 'directory')
        download_and_extract_data(data_url, data_dir)
        all_data = load_all_wavs(data_dir)

        if slate.consume(preprocessing_config, 'use_delta_encoding'):
            all_data = [delta_encode(d) for d in all_data]

        # Split data into train and test sets
        split_ratio = slate.consume(data_config, 'split_ratio', 0.8)
        split_idx = int(len(all_data) * split_ratio)
        self.train_data = all_data[:split_idx]
        self.test_data = all_data[split_idx:]

        # Model setup
        self.model = self.get_model(predictor_config)
        self.encoder = self.get_encoder(bitstream_config)
        self.epochs = slate.consume(training_config, 'epochs')
        self.batch_size = slate.consume(training_config, 'batch_size')
        self.learning_rate = slate.consume(training_config, 'learning_rate')
        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
        self.eval_freq = slate.consume(training_config, 'eval_freq')
        self.save_path = slate.consume(training_config, 'save_path', 'models')

    def get_model(self, config):
        model_type = self.slate.consume(config, 'type')
        if model_type == 'lstm':
            return LSTMPredictor(
                input_size=self.slate.consume(config, 'input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size'),
                num_layers=self.slate.consume(config, 'num_layers')
            )
        elif model_type == 'fixed_input_nn':
            return FixedInputNNPredictor(
                input_size=self.slate.consume(config, 'fixed_input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size')
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def get_encoder(self, config):
        return ArithmeticEncoder()

    def run(self, run, forceNoProfile=False):
        train_model(self.model, self.train_data, self.test_data, self.epochs, self.batch_size, self.learning_rate, self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path)

if __name__ == '__main__':
    slate = Slate({'spikey': SpikeRunner})
    slate.from_args()

model.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from abc import ABC, abstractmethod

class BaseModel(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data):
        pass

class LSTMPredictor(BaseModel):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTMPredictor, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, (h0, c0))
        out = self.fc(out)
        return out

    def encode(self, data):
        self.eval()
        encoded_data = []

        with torch.no_grad():
            for i in range(len(data) - 1):
                context = torch.tensor(data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context).item()
                delta = data[i] - prediction
                encoded_data.append(delta)

        return encoded_data

    def decode(self, encoded_data):
        self.eval()
        decoded_data = []

        with torch.no_grad():
            for i in range(len(encoded_data)):
                context = torch.tensor(decoded_data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context).item()
                decoded_data.append(prediction + encoded_data[i])

        return decoded_data

class FixedInputNNPredictor(BaseModel):
    def __init__(self, input_size, hidden_size):
        super(FixedInputNNPredictor, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.input_size = input_size

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def encode(self, data):
        self.eval()
        encoded_data = []

        with torch.no_grad():
            for i in range(len(data) - self.input_size):
                context = torch.tensor(data[i:i + self.input_size]).view(1, -1).float()
                prediction = self.forward(context).item()
                delta = data[i + self.input_size] - prediction
                encoded_data.append(delta)

        return encoded_data

    def decode(self, encoded_data):
        self.eval()
        decoded_data = []

        with torch.no_grad():
            for i in range(len(encoded_data)):
                context = torch.tensor(decoded_data[max(0, i - self.input_size):i]).view(1, -1).float()
                prediction = self.forward(context).item()
                decoded_data.append(prediction + encoded_data[i])

        return decoded_data
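
A minimal shape check for the two predictors above (a sketch, not part of this commit; a batch of 4 sequences, 16 timesteps, 1 feature):

```python
import torch
from model import LSTMPredictor, FixedInputNNPredictor

lstm = LSTMPredictor(input_size=1, hidden_size=128, num_layers=2)
x = torch.randn(4, 16, 1)       # (batch, time, features)
print(lstm(x).shape)            # torch.Size([4, 16, 1]): one prediction per timestep

nn_pred = FixedInputNNPredictor(input_size=10, hidden_size=128)
window = torch.randn(4, 10)     # (batch, fixed context window)
print(nn_pred(window).shape)    # torch.Size([4, 1]): one next-sample prediction
```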
							
								
								
									
requirements.txt (new file, 6 lines)
@@ -0,0 +1,6 @@
torch
numpy
scipy
matplotlib
wandb
pyyaml

train.py (new file, 113 lines)
@@ -0,0 +1,113 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import random
import os
import pickle
from data_processing import delta_encode, delta_decode, save_wav
from utils import visualize_prediction, plot_delta_distribution
from bitstream import ArithmeticEncoder

def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0, num_points=None):
    compression_ratios = []
    identical_count = 0
    all_deltas = []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for file_data in data:
        file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(device)
        encoded_data = model(file_data).squeeze(1).cpu().detach().numpy().tolist()
        encoder.build_model(encoded_data)
        compressed_data = encoder.encode(encoded_data)
        decompressed_data = encoder.decode(compressed_data, len(encoded_data))

        # Check equivalence
        if use_delta_encoding:
            decompressed_data = delta_decode(decompressed_data)
        identical = np.allclose(file_data.cpu().numpy(), decompressed_data, atol=1e-5)
        if identical:
            identical_count += 1

        compression_ratio = len(file_data) / len(compressed_data)
        compression_ratios.append(compression_ratio)

        # Compute and collect deltas
        predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(device)).squeeze(1).cpu().detach().numpy().tolist()
        if use_delta_encoding:
            predicted_data = delta_decode(predicted_data)
        delta_data = [file_data[i].item() - predicted_data[i] for i in range(len(file_data))]
        all_deltas.extend(delta_data)

        # Visualize prediction vs data vs error
        visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate, num_points)

    identical_percentage = (identical_count / len(data)) * 100

    # Plot delta distribution
    delta_plot_path = plot_delta_distribution(all_deltas, epoch)
    wandb.log({"delta_distribution": wandb.Image(delta_plot_path)})

    return compression_ratios, identical_percentage

def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path, num_points=None):
    """Train the model."""
    wandb.init(project="wav-compression")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_test_score = float('-inf')  # Higher mean compression ratio is better.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(epochs):
        total_loss = 0
        random.shuffle(train_data)  # Shuffle data for varied batches
        for i in range(0, len(train_data) - batch_size, batch_size):
            inputs = torch.tensor(train_data[i:i+batch_size], dtype=torch.float32).unsqueeze(2).to(device)
            targets = torch.tensor(train_data[i+1:i+batch_size+1], dtype=torch.float32).unsqueeze(2).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        wandb.log({"epoch": epoch, "loss": total_loss})
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')

        if (epoch + 1) % eval_freq == 0:
            # Evaluate on train and test data
            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)

            # Log statistics
            wandb.log({
                "train_compression_ratio_mean": np.mean(train_compression_ratios),
                "train_compression_ratio_std": np.std(train_compression_ratios),
                "train_compression_ratio_min": np.min(train_compression_ratios),
                "train_compression_ratio_max": np.max(train_compression_ratios),
                "test_compression_ratio_mean": np.mean(test_compression_ratios),
                "test_compression_ratio_std": np.std(test_compression_ratios),
                "test_compression_ratio_min": np.min(test_compression_ratios),
                "test_compression_ratio_max": np.max(test_compression_ratios),
                "train_identical_percentage": train_identical_percentage,
                "test_identical_percentage": test_identical_percentage,
            })

            print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
            print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')

            # Save model and encoder if new highscore on test data
            test_score = np.mean(test_compression_ratios)
            if test_score > best_test_score:
                best_test_score = test_score
                model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
                encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
                torch.save(model.state_dict(), model_path)
                with open(encoder_path, 'wb') as f:
                    pickle.dump(encoder, f)
                print(f'New highscore on test data! Model and encoder saved to {model_path} and {encoder_path}')
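
Note that the loop above pairs `train_data[i:i+batch_size]` with `train_data[i+1:i+batch_size+1]`, so targets are the inputs offset by one file in the shuffled list, which also requires all recordings to have equal length. For comparison, one-step-ahead pairing within a single recording would look like this (a sketch; `model` and `criterion` as in `train_model`):

```python
import torch

recording = torch.randn(100)               # one file's samples (dummy data)
inputs = recording[:-1].view(1, -1, 1)     # samples 0..T-2 as context
targets = recording[1:].view(1, -1, 1)     # samples 1..T-1: next-sample targets
# loss = criterion(model(inputs), targets) # with an LSTMPredictor as `model`
```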
							
								
								
									
utils.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import matplotlib.pyplot as plt
import numpy as np
import wandb
import os

def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
    """Visualize WAV data using matplotlib."""
    if num_points:
        data = data[:num_points]
    plt.figure(figsize=(10, 4))
    plt.plot(np.linspace(0, len(data) / sample_rate, num=len(data)), data)
    plt.title(title)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')
    plt.show()

def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None):
    """Visualize the true data, predicted data, and deltas."""
    if num_points:
        true_data = true_data[:num_points]
        predicted_data = predicted_data[:num_points]
        delta_data = delta_data[:num_points]

    plt.figure(figsize=(15, 5))

    plt.subplot(3, 1, 1)
    plt.plot(true_data, label='True Data')
    plt.title('True Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 2)
    plt.plot(predicted_data, label='Predicted Data', color='orange')
    plt.title('Predicted Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 3)
    plt.plot(delta_data, label='Delta', color='red')
    plt.title('Delta')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.tight_layout()
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png')
    plt.savefig(file_path)
    plt.close()
    wandb.log({"Prediction vs True Data": wandb.Image(file_path)})

def plot_delta_distribution(deltas, epoch):
    """Plot the distribution of deltas."""
    plt.figure(figsize=(10, 6))
    plt.hist(deltas, bins=100, density=True, alpha=0.6, color='g')
    plt.title(f'Delta Distribution at Epoch {epoch}')
    plt.xlabel('Delta')
    plt.ylabel('Density')
    plt.grid(True)
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1e6)}.png')
    plt.savefig(file_path)
    plt.close()
    return file_path