initial commit
commit 0c5f888d75

.gitignore (vendored, new file)

data
data.zip
__pycache__
.venv
wandb
slurm_log
job_hist.log
models
Xvfb.log
profiler

README.md (new file)

# Spikey

This repository contains a solution for the [Neuralink Compression Challenge](https://content.neuralink.com/compression-challenge/README.html). The challenge involves compressing raw electrode recordings from a Neuralink implant, taken from the motor cortex of a non-human primate while it plays a video game.

## Challenge Overview

The Neuralink N1 implant generates approximately 200 Mbps of electrode data (1024 electrodes @ 20 kHz, 10-bit resolution) but can only transmit about 1 Mbps wirelessly, so a compression ratio of over 200x is required. The compression must run in real time (< 1 ms) and consume low power (< 10 mW, including the radio).
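
These headline numbers can be sanity-checked directly; a quick back-of-the-envelope calculation in plain Python (no project code involved):

```python
# Raw data rate of the N1 implant, from the challenge description.
electrodes = 1024
sample_rate_hz = 20_000
bits_per_sample = 10

raw_bps = electrodes * sample_rate_hz * bits_per_sample  # 204,800,000 bits/s
radio_bps = 1_000_000                                    # ~1 Mbps downlink

print(f"raw rate: {raw_bps / 1e6:.1f} Mbps")                      # ~204.8 Mbps
print(f"required compression ratio: {raw_bps / radio_bps:.0f}x")  # ~205x
```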

## Installation

To install the necessary dependencies, create a virtual environment and install the requirements:

```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```

## Usage

### Configuration

The configuration for training and evaluation is specified in a YAML file. Below is an example configuration:

```yaml
name: Test

preprocessing:
  use_delta_encoding: true # Whether to use delta encoding.

predictor:
  type: lstm # Options: 'lstm', 'fixed_input_nn'
  input_size: 1 # Input size for the LSTM predictor.
  hidden_size: 128 # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2 # Number of layers for the LSTM predictor.
  fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10 # Number of training epochs.
  batch_size: 32 # Batch size for training.
  learning_rate: 0.001 # Learning rate for the optimizer.
  eval_freq: 2 # Frequency of evaluation during training (in epochs).
  save_path: models # Directory to save the best model and encoder.
  num_points: 1000 # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
  directory: data # Directory to extract and store the dataset.
  split_ratio: 0.8 # Ratio to split the data into train and test sets.
```
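
For reference, the example above is a single YAML document and loads with plain PyYAML (the repository's `config.yaml` bundles several documents separated by `---`, which needs `yaml.safe_load_all` instead); a minimal sketch, assuming the example is saved as `example.yaml`:

```python
import yaml

# 'example.yaml' is a hypothetical file holding the single-document config above.
with open("example.yaml") as f:
    config = yaml.safe_load(f)

print(config["predictor"]["type"])          # 'lstm'
print(config["training"]["learning_rate"])  # 0.001
```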

### Running the Code

To compress or decompress WAV files with a trained model, use the provided CLI (training itself is covered in the next section):

```bash
python cli.py compress --config config.yaml --input_file path/to/input.wav --output_file path/to/output.bin
python cli.py decompress --config config.yaml --input_file path/to/output.bin --output_file path/to/output.wav
```

### Training

Training requires Slate, which is not currently publicly available. Install it via (repository access required):

```bash
pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
```

To train the model, run:

```bash
python main.py config.yaml Test
```

bitstream.py (new file)

from abc import ABC, abstractmethod
from arithmetic_compressor import AECompressor
from arithmetic_compressor.models import StaticModel


class BaseEncoder(ABC):
    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data, num_symbols):
        pass

    @abstractmethod
    def build_model(self, data):
        pass


class ArithmeticEncoder(BaseEncoder):
    def encode(self, data):
        if not hasattr(self, 'model'):
            raise ValueError("Model not built. Call build_model(data) before encoding.")
        coder = AECompressor(self.model)
        compressed_data = coder.compress(data)
        return compressed_data

    def decode(self, encoded_data, num_symbols):
        if not hasattr(self, 'model'):
            raise ValueError("Model not built. Call build_model(data) before decoding.")
        coder = AECompressor(self.model)
        decoded_data = coder.decompress(encoded_data, num_symbols)
        return decoded_data

    def build_model(self, data):
        # Static (order-0) symbol model: empirical frequency of each symbol.
        symbol_counts = {symbol: data.count(symbol) for symbol in set(data)}
        total_symbols = sum(symbol_counts.values())
        probabilities = {symbol: count / total_symbols for symbol, count in symbol_counts.items()}
        self.model = StaticModel(probabilities)
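
A minimal round trip over a toy symbol stream, as a sketch of how `ArithmeticEncoder` is meant to be used (the symbols here are arbitrary small integers, not real electrode data):

```python
from bitstream import ArithmeticEncoder

symbols = [0, 1, 1, 2, 1, 0, 1, 1]

encoder = ArithmeticEncoder()
encoder.build_model(symbols)    # fit the static symbol model first
bits = encoder.encode(symbols)  # compressed bit list
restored = encoder.decode(bits, len(symbols))
assert restored == symbols      # lossless round trip
```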

cli.py (new file)

import argparse
import yaml
import torch
from data_processing import load_all_wavs, save_wav, delta_encode, delta_decode
from main import SpikeRunner


def load_config(config_path):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config


def main():
    parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
    parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
    parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
    parser.add_argument('--input_file', help="Path to the input file (a directory of WAV files for compress)")
    parser.add_argument('--output_file', help="Path to the output file")
    args = parser.parse_args()

    config = load_config(args.config)

    spike_runner = SpikeRunner(None, config)
    spike_runner.setup('CLI')

    if args.action == 'compress':
        data = load_all_wavs(args.input_file)  # NOTE: load_all_wavs expects a directory, not a single file
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            data = [delta_encode(d) for d in data]

        # Run the predictor over each recording, then fit a per-recording symbol
        # model and arithmetic-code the outputs (mirrors evaluate_model in train.py).
        encoded_data = [spike_runner.model(torch.tensor(d, dtype=torch.float32).view(1, -1, 1)).view(-1).detach().numpy().tolist() for d in data]
        compressed_data = []
        for ed in encoded_data:
            spike_runner.encoder.build_model(ed)
            compressed_data.append(spike_runner.encoder.encode(ed))

        # NOTE: this container stores no framing or length metadata, so the
        # decompress branch below cannot recover per-recording boundaries;
        # it is a placeholder until a proper container format exists.
        with open(args.output_file, 'wb') as f:
            for cd in compressed_data:
                f.write(bytearray(cd))

    elif args.action == 'decompress':
        with open(args.input_file, 'rb') as f:
            compressed_data = list(f.read())

        decoded_data = [spike_runner.encoder.decode(cd, len(cd)) for cd in compressed_data]
        if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
            decoded_data = [delta_decode(dd) for dd in decoded_data]

        save_wav(args.output_file, 19531, decoded_data)  # assuming a 19531 Hz sample rate


if __name__ == "__main__":
    main()

config.yaml (new file)

name: DEFAULT
project: Spikey

slurm:
  name: 'Spikey_{config[name]}'
  partitions:
    - single
  standard_output: ./reports/slurm/out_%A_%a.log
  standard_error: ./reports/slurm/err_%A_%a.log
  num_parallel_jobs: 50
  cpus_per_task: 4
  memory_per_cpu: 1000
  time_limit: 1440 # in minutes
  ntasks: 1
  venv: '.venv/bin/activate'
  sh_lines:
    - 'mkdir -p {tmp}/wandb'
    - 'mkdir -p {tmp}/local_pycache'
    - 'export PYTHONPYCACHEPREFIX={tmp}/local_pycache'

runner: spikey

scheduler:
  reps_per_version: 1
  agents_per_job: 1
  reps_per_agent: 1

wandb:
  project: '{config[project]}'
  group: '{config[name]}'
  job_type: '{delta_desc}'
  name: '{job_id}_{task_id}:{run_id}:{rand}={config[name]}_{delta_desc}'
  tags:
    - '{config[env][name]}'
    - '{config[algo][name]}'
  sync_tensorboard: False
  monitor_gym: False
  save_code: False

---
name: Test

preprocessing:
  use_delta_encoding: true # Whether to use delta encoding.

predictor:
  type: lstm # Options: 'lstm', 'fixed_input_nn'
  input_size: 1 # Input size for the LSTM predictor.
  hidden_size: 128 # Hidden size for the LSTM or Fixed Input NN predictor.
  num_layers: 2 # Number of layers for the LSTM predictor.
  fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.

training:
  epochs: 10 # Number of training epochs.
  batch_size: 32 # Batch size for training.
  learning_rate: 0.001 # Learning rate for the optimizer.
  eval_freq: 2 # Frequency of evaluation during training (in epochs).
  save_path: models # Directory to save the best model and encoder.
  num_points: 1000 # Number of data points to visualize.

bitstream_encoding:
  type: arithmetic # Use arithmetic encoding.

data:
  url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
  directory: data # Directory to extract and store the dataset.
  split_ratio: 0.8 # Ratio to split the data into train and test sets.
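
Note that `config.yaml` holds two YAML documents separated by `---` (a `DEFAULT` base and a `Test` experiment; how they are merged is up to Slate). The raw documents can be inspected with PyYAML's multi-document loader, as a quick sanity check:

```python
import yaml

with open("config.yaml") as f:
    docs = list(yaml.safe_load_all(f))

print([d["name"] for d in docs])  # ['DEFAULT', 'Test']
```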

data_processing.py (new file)

import os
import numpy as np
from scipy.io import wavfile
import urllib.request
import zipfile


def download_and_extract_data(url, data_dir):
    """Download the dataset archive and extract it; skipped if data_dir already exists."""
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
        zip_path = os.path.join(data_dir, 'data.zip')
        urllib.request.urlretrieve(url, zip_path)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        os.remove(zip_path)


def load_wav(file_path):
    """Load WAV file and return sample rate and data."""
    sample_rate, data = wavfile.read(file_path)
    return sample_rate, data


def load_all_wavs(data_dir):
    """Load all WAV files in the given directory."""
    wav_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
    all_data = []
    for file_path in wav_files:
        _, data = load_wav(file_path)
        all_data.append(data)
    return all_data


def save_wav(file_path, sample_rate, data):
    """Save data to a WAV file (written as float32 samples)."""
    wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))


def delta_encode(data):
    """Delta encoding: keep the first sample, then store successive differences."""
    deltas = [data[0]]
    for i in range(1, len(data)):
        deltas.append(data[i] - data[i - 1])
    return deltas


def delta_decode(deltas):
    """Decode delta encoded data by cumulative summation."""
    data = [deltas[0]]
    for i in range(1, len(deltas)):
        data.append(data[-1] + deltas[i])
    return data
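
A quick round trip of the delta coder, as a sketch (plain Python lists; the values are made up):

```python
from data_processing import delta_encode, delta_decode

samples = [100, 102, 101, 105, 105]
deltas = delta_encode(samples)   # [100, 2, -1, 4, 0]
restored = delta_decode(deltas)  # [100, 102, 101, 105, 105]
assert restored == samples
```

Delta encoding helps because neighbouring electrode samples are strongly correlated, so the differences concentrate near zero and hand the arithmetic coder a much more skewed, and hence more compressible, symbol distribution.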

main.py (new file)

from slate import Slate, Slate_Runner

from data_processing import download_and_extract_data, load_all_wavs, delta_encode
from model import LSTMPredictor, FixedInputNNPredictor
from train import train_model
from bitstream import ArithmeticEncoder


class SpikeRunner(Slate_Runner):
    def setup(self, name):
        self.name = name
        slate, config = self.slate, self.config

        # Consume config sections
        preprocessing_config = slate.consume(config, 'preprocessing', expand=True)
        predictor_config = slate.consume(config, 'predictor', expand=True)
        training_config = slate.consume(config, 'training', expand=True)
        bitstream_config = slate.consume(config, 'bitstream_encoding', expand=True)
        data_config = slate.consume(config, 'data', expand=True)

        # Data setup
        data_url = slate.consume(data_config, 'url')
        data_dir = slate.consume(data_config, 'directory')
        download_and_extract_data(data_url, data_dir)
        all_data = load_all_wavs(data_dir)

        # Consume once and reuse: 'consume' removes the key, so a second
        # consume of the same key would not find it.
        self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
        if self.use_delta_encoding:
            all_data = [delta_encode(d) for d in all_data]

        # Split data into train and test sets
        split_ratio = slate.consume(data_config, 'split_ratio', 0.8)
        split_idx = int(len(all_data) * split_ratio)
        self.train_data = all_data[:split_idx]
        self.test_data = all_data[split_idx:]

        # Model setup
        self.model = self.get_model(predictor_config)
        self.encoder = self.get_encoder(bitstream_config)
        self.epochs = slate.consume(training_config, 'epochs')
        self.batch_size = slate.consume(training_config, 'batch_size')
        self.learning_rate = slate.consume(training_config, 'learning_rate')
        self.eval_freq = slate.consume(training_config, 'eval_freq')
        self.save_path = slate.consume(training_config, 'save_path', 'models')

    def get_model(self, config):
        model_type = self.slate.consume(config, 'type')
        if model_type == 'lstm':
            return LSTMPredictor(
                input_size=self.slate.consume(config, 'input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size'),
                num_layers=self.slate.consume(config, 'num_layers')
            )
        elif model_type == 'fixed_input_nn':
            return FixedInputNNPredictor(
                input_size=self.slate.consume(config, 'fixed_input_size'),
                hidden_size=self.slate.consume(config, 'hidden_size')
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def get_encoder(self, config):
        # Only arithmetic coding is implemented so far.
        return ArithmeticEncoder()

    def run(self, run, forceNoProfile=False):
        train_model(self.model, self.train_data, self.test_data, self.epochs, self.batch_size,
                    self.learning_rate, self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path)


if __name__ == '__main__':
    slate = Slate({'spikey': SpikeRunner})
    slate.from_args()

model.py (new file)

import torch
import torch.nn as nn
from abc import ABC, abstractmethod


class BaseModel(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def encode(self, data):
        pass

    @abstractmethod
    def decode(self, encoded_data):
        pass


class LSTMPredictor(BaseModel):
    def __init__(self, input_size, hidden_size, num_layers):
        super(LSTMPredictor, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, (h0, c0))
        out = self.fc(out)
        return out

    def encode(self, data):
        """Residual-code the signal: store the first sample raw, then prediction errors."""
        self.eval()
        encoded_data = [data[0]]  # seed so decode can bootstrap its context

        with torch.no_grad():
            for i in range(1, len(data)):
                # Lookback window (hidden_size doubles as the context length here).
                context = torch.tensor(data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context)[0, -1].item()  # last time step's prediction
                encoded_data.append(data[i] - prediction)

        return encoded_data

    def decode(self, encoded_data):
        self.eval()
        decoded_data = [encoded_data[0]]  # first sample was stored raw

        with torch.no_grad():
            for i in range(1, len(encoded_data)):
                context = torch.tensor(decoded_data[max(0, i - self.hidden_size):i]).view(1, -1, 1).float()
                prediction = self.forward(context)[0, -1].item()
                decoded_data.append(prediction + encoded_data[i])

        return decoded_data


class FixedInputNNPredictor(BaseModel):
    def __init__(self, input_size, hidden_size):
        super(FixedInputNNPredictor, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.input_size = input_size

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def encode(self, data):
        """Store the first input_size samples raw, then prediction errors for the rest."""
        self.eval()
        encoded_data = list(data[:self.input_size])  # seed window for the decoder

        with torch.no_grad():
            for i in range(len(data) - self.input_size):
                context = torch.tensor(data[i:i + self.input_size]).view(1, -1).float()
                prediction = self.forward(context).item()
                encoded_data.append(data[i + self.input_size] - prediction)

        return encoded_data

    def decode(self, encoded_data):
        self.eval()
        decoded_data = list(encoded_data[:self.input_size])  # raw seed window

        with torch.no_grad():
            for i in range(self.input_size, len(encoded_data)):
                context = torch.tensor(decoded_data[i - self.input_size:i]).view(1, -1).float()
                prediction = self.forward(context).item()
                decoded_data.append(prediction + encoded_data[i])

        return decoded_data
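
The predictors implement predictive (residual) coding: each sample is predicted from its recent context and only the prediction error is kept. A small round-trip sketch with untrained weights and made-up data (with a trained model the residuals shrink, which is the whole point):

```python
import torch
from model import FixedInputNNPredictor

torch.manual_seed(0)
model = FixedInputNNPredictor(input_size=10, hidden_size=128)

data = [float(i % 7) for i in range(50)]  # arbitrary toy signal
residuals = model.encode(data)
restored = model.decode(residuals)

# Lossless up to float rounding, even with an untrained model.
assert all(abs(a - b) < 1e-4 for a, b in zip(data, restored))
```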

requirements.txt (new file)

torch
numpy
scipy
matplotlib
wandb
pyyaml
arithmetic-compressor  # provides the arithmetic_compressor module used by bitstream.py

train.py (new file)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import random
import os
import pickle
from data_processing import delta_decode
from utils import visualize_prediction, plot_delta_distribution


def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0, num_points=None):
    compression_ratios = []
    identical_count = 0
    all_deltas = []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for file_data in data:
        # Shape (1, seq_len, 1): a batch of one recording, one feature per time step.
        file_tensor = torch.tensor(file_data, dtype=torch.float32).view(1, -1, 1).to(device)
        encoded_data = model(file_tensor).view(-1).cpu().detach().numpy().tolist()
        encoder.build_model(encoded_data)
        compressed_data = encoder.encode(encoded_data)
        decompressed_data = encoder.decode(compressed_data, len(encoded_data))

        # Check round-trip equivalence
        if use_delta_encoding:
            decompressed_data = delta_decode(decompressed_data)
        identical = np.allclose(np.asarray(file_data), decompressed_data, atol=1e-5)
        if identical:
            identical_count += 1

        # Rough proxy: input samples per compressed symbol
        compression_ratio = len(file_data) / len(compressed_data)
        compression_ratios.append(compression_ratio)

        # Compute and collect deltas between the signal and the model's predictions
        predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).view(1, -1, 1).to(device)).view(-1).cpu().detach().numpy().tolist()
        if use_delta_encoding:
            predicted_data = delta_decode(predicted_data)
        delta_data = [file_data[i] - predicted_data[i] for i in range(len(file_data))]
        all_deltas.extend(delta_data)

        # Visualize prediction vs data vs error
        visualize_prediction(np.asarray(file_data), predicted_data, delta_data, sample_rate, num_points)

    identical_percentage = (identical_count / len(data)) * 100

    # Plot delta distribution
    delta_plot_path = plot_delta_distribution(all_deltas, epoch)
    wandb.log({"delta_distribution": wandb.Image(delta_plot_path)})

    return compression_ratios, identical_percentage


def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path, num_points=None):
    """Train the model."""
    wandb.init(project="wav-compression")
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_test_score = float('-inf')  # higher mean compression ratio is better

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(epochs):
        total_loss = 0
        random.shuffle(train_data)  # Shuffle data for varied batches
        # NOTE: stacking recordings into one tensor assumes they all have the same length.
        for i in range(0, len(train_data) - batch_size, batch_size):
            inputs = torch.tensor(train_data[i:i+batch_size], dtype=torch.float32).unsqueeze(2).to(device)
            targets = torch.tensor(train_data[i+1:i+batch_size+1], dtype=torch.float32).unsqueeze(2).to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        wandb.log({"epoch": epoch, "loss": total_loss})
        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')

        if (epoch + 1) % eval_freq == 0:
            # Evaluate on train and test data
            train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)
            test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch, num_points=num_points)

            # Log statistics
            wandb.log({
                "train_compression_ratio_mean": np.mean(train_compression_ratios),
                "train_compression_ratio_std": np.std(train_compression_ratios),
                "train_compression_ratio_min": np.min(train_compression_ratios),
                "train_compression_ratio_max": np.max(train_compression_ratios),
                "test_compression_ratio_mean": np.mean(test_compression_ratios),
                "test_compression_ratio_std": np.std(test_compression_ratios),
                "test_compression_ratio_min": np.min(test_compression_ratios),
                "test_compression_ratio_max": np.max(test_compression_ratios),
                "train_identical_percentage": train_identical_percentage,
                "test_identical_percentage": test_identical_percentage,
            })

            print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
            print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')

            # Save model and encoder on a new best (highest) test compression ratio
            test_score = np.mean(test_compression_ratios)
            if test_score > best_test_score:
                best_test_score = test_score
                os.makedirs(save_path, exist_ok=True)
                model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
                encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
                torch.save(model.state_dict(), model_path)
                with open(encoder_path, 'wb') as f:
                    pickle.dump(encoder, f)
                print(f'New highscore on test data! Model and encoder saved to {model_path} and {encoder_path}')
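
For experimenting outside the Slate runner, `train_model` can also be driven directly. A minimal sketch with toy data, assuming wandb is configured (e.g. `WANDB_MODE=offline`) and using equal-length dummy recordings in place of real electrode data:

```python
import numpy as np
from model import LSTMPredictor
from bitstream import ArithmeticEncoder
from train import train_model

# 16 fake "recordings" of 256 samples each, standing in for WAV data.
data = [np.sin(np.linspace(0, 8, 256) + p).tolist() for p in np.linspace(0, 3, 16)]

train_model(
    model=LSTMPredictor(input_size=1, hidden_size=32, num_layers=1),
    train_data=data[:12], test_data=data[12:],
    epochs=2, batch_size=4, learning_rate=1e-3,
    use_delta_encoding=False, encoder=ArithmeticEncoder(),
    eval_freq=2, save_path='models',
)
```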

utils.py (new file)

import matplotlib.pyplot as plt
import numpy as np
import wandb
import os


def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None):
    """Visualize WAV data using matplotlib."""
    if num_points:
        data = data[:num_points]
    plt.figure(figsize=(10, 4))
    plt.plot(np.linspace(0, len(data) / sample_rate, num=len(data)), data)
    plt.title(title)
    plt.xlabel('Time [s]')
    plt.ylabel('Amplitude')
    plt.show()


def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None):
    """Visualize the true data, predicted data, and deltas, and log the figure to wandb."""
    if num_points:
        true_data = true_data[:num_points]
        predicted_data = predicted_data[:num_points]
        delta_data = delta_data[:num_points]

    plt.figure(figsize=(15, 5))

    plt.subplot(3, 1, 1)
    plt.plot(true_data, label='True Data')
    plt.title('True Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 2)
    plt.plot(predicted_data, label='Predicted Data', color='orange')
    plt.title('Predicted Data')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.subplot(3, 1, 3)
    plt.plot(delta_data, label='Delta', color='red')
    plt.title('Delta')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')

    plt.tight_layout()
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1_000_000)}.png')
    plt.savefig(file_path)
    plt.close()
    wandb.log({"Prediction vs True Data": wandb.Image(file_path)})


def plot_delta_distribution(deltas, epoch):
    """Plot the distribution of deltas and return the saved figure's path."""
    plt.figure(figsize=(10, 6))
    plt.hist(deltas, bins=100, density=True, alpha=0.6, color='g')
    plt.title(f'Delta Distribution at Epoch {epoch}')
    plt.xlabel('Delta')
    plt.ylabel('Density')
    plt.grid(True)
    tmp_dir = os.getenv('TMPDIR', '/tmp')
    file_path = os.path.join(tmp_dir, f'delta_distribution_epoch_{epoch}_{np.random.randint(1_000_000)}.png')
    plt.savefig(file_path)
    plt.close()
    return file_path