Changed everything

This commit is contained in:
Dominik Moritz Roth 2024-05-25 17:31:08 +02:00
parent 73b306dc50
commit 29946baff0
10 changed files with 1285447 additions and 390 deletions

2
.gitignore vendored

@ -8,3 +8,5 @@ job_hist.log
models
Xvfb.log
profiler
.ipynb_checkpoints/


@ -6,6 +6,24 @@ This repository contains a solution for the [Neuralink Compression Challenge](ht
The Neuralink N1 implant generates approximately 200Mbps of electrode data (1024 electrodes @ 20kHz, 10-bit resolution) and can transmit data wirelessly at about 1Mbps. This means a compression ratio of over 200x is required. The compression must run in real-time (< 1ms) and consume low power (< 10mW, including radio).
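As a quick sanity check on these figures (using the stated 1024 electrodes, 20 kHz sampling, and 10-bit resolution):

```python
# Back-of-the-envelope check of the bandwidth figures quoted above.
raw_bps = 1024 * 20_000 * 10      # ~204.8 Mbps of raw electrode data
radio_bps = 1_000_000             # ~1 Mbps wireless budget
print(raw_bps / 1e6)              # 204.8 (Mbps)
print(raw_bps / radio_bps)        # ~205x compression required
```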
## Data Analysis
The `analysis.ipynb` notebook contains a detailed analysis of the data. We found that there is sometimes significant cross-correlation between the different leads, so we consider it vital to exploit this information for better compression: the cross-correlation improves the accuracy of our predictions and reduces the amount of data that needs to be transmitted. The analysis also suggests that a 200x compression ratio is highly unlikely to be achievable and, we would argue, not a sensible target; a very close reproduction of the signal is sufficient.
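A minimal vectorized sketch of the lead-to-lead correlation referenced here (the repo's `compute_correlation_matrix` computes the same quantity pairwise):

```python
import numpy as np

def lead_correlations(leads):
    """Pearson correlation between all pairs of leads.

    leads: list of 1-D arrays (one per electrode), truncated to a common length.
    Returns a (num_leads, num_leads) correlation matrix.
    """
    n = min(len(x) for x in leads)
    stacked = np.stack([np.asarray(x[:n], dtype=np.float64) for x in leads])
    return np.corrcoef(stacked)
```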
## Compression Details
The solution leverages three neural network models to achieve effective compression:
1. **Latent Projector**: Takes a segment of a lead and projects it into a latent space. Depending on the configuration, this is either a fully connected network or an RNN (LSTM).
2. **MiddleOut (Message Passer)**: For each lead, this module looks up the `n` most correlated leads and uses their latent representations along with their correlation values to generate a new latent representation. This is done by training a fully connected layer to map from (our_latent, their_latent, correlation) -> new_latent and then averaging over all new_latent values to get the final representation.
3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training.
By accurately predicting the next timestep, the delta (difference) between the actual and predicted value stays small. Small deltas need fewer bits to store, and the remaining residuals are packed by the bitstream encoder (a rough sketch of this flow follows below).
The neural networks used in this solution are tiny, which makes it plausible to meet the latency and power requirements with a more efficient implementation.
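To make the data flow concrete, here is a minimal, hypothetical sketch of the per-lead pipeline (the helper, its signature, and the peer-latent bookkeeping are illustrative assumptions, not the repository's exact API; the default `window` of 50 mirrors the config's `input_size`):

```python
import torch

def compress_lead(samples, projector, middle_out, predictor,
                  peer_latents, peer_corrs, window=50):
    """Turn one lead (a 1-D tensor) into a stream of small residuals (deltas).

    The decoder runs the same three models, reproduces the predictions,
    and adds the residuals back to reconstruct the signal.
    """
    deltas = []
    for t in range(window, len(samples) - 1):
        my_latent = projector(samples[t - window:t])                 # 1. project a snippet of this lead
        mixed = middle_out(my_latent, peer_latents[t], peer_corrs)   # 2. fold in the most correlated peers
        prediction = predictor(mixed)                                # 3. predict the next timestep
        deltas.append(samples[t + 1] - prediction)                   # small residuals -> few bits after encoding
    return torch.stack(deltas)
```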
## Installation
To install the necessary dependencies, create a virtual environment and install the requirements:
@ -18,18 +36,9 @@ pip install -r requirements.txt
## Usage
### Running the Code
To train the model and compress/decompress WAV files, use the CLI provided:
```bash
python cli.py compress --config config.yaml --input_file path/to/input.wav --output_file path/to/output.bin
python cli.py decompress --config config.yaml --input_file path/to/output.bin --output_file path/to/output.wav
```
### Training
Requires Slate, which is not currently publicly available. Install via (requires repo access):
```bash
pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
@ -38,5 +47,6 @@ pip install -e git+ssh://git@dominik-roth.eu/dodox/Slate.git#egg=slate
To train the model, run:
```bash
python main.py config.yaml Test
python main.py <config_file.yaml> <exp_name>
```

1285119
analysis.ipynb Normal file

File diff suppressed because one or more lines are too long

50
cli.py

@ -1,50 +0,0 @@
import argparse
import yaml
import os
import torch
from data_processing import download_and_extract_data, load_all_wavs, save_wav, delta_encode, delta_decode
from main import SpikeRunner
def load_config(config_path):
with open(config_path, 'r') as file:
config = yaml.safe_load(file)
return config
def main():
parser = argparse.ArgumentParser(description="WAV Compression with Neural Networks")
parser.add_argument('action', choices=['compress', 'decompress'], help="Action to perform")
parser.add_argument('--config', default='config.yaml', help="Path to the configuration file")
parser.add_argument('--input_file', help="Path to the input WAV file")
parser.add_argument('--output_file', help="Path to the output file")
args = parser.parse_args()
config = load_config(args.config)
spike_runner = SpikeRunner(None, config)
spike_runner.setup('CLI')
if args.action == 'compress':
data = load_all_wavs(args.input_file)
if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
data = [delta_encode(d) for d in data]
spike_runner.encoder.build_model(data)
encoded_data = [spike_runner.model(torch.tensor(d, dtype=torch.float32).unsqueeze(0)).squeeze(0).detach().numpy().tolist() for d in data]
compressed_data = [spike_runner.encoder.encode(ed) for ed in encoded_data]
with open(args.output_file, 'wb') as f:
for cd in compressed_data:
f.write(bytearray(cd))
elif args.action == 'decompress':
with open(args.input_file, 'rb') as f:
compressed_data = list(f.read())
decoded_data = [spike_runner.encoder.decode(cd, len(cd)) for cd in compressed_data]
if spike_runner.slate.consume(config['preprocessing'], 'use_delta_encoding'):
decoded_data = [delta_decode(dd) for dd in decoded_data]
save_wav(args.output_file, 19531, decoded_data) # Assuming 19531 Hz sample rate
if __name__ == "__main__":
main()


@ -41,15 +41,22 @@ wandb:
name: Test
import: $
preprocessing:
use_delta_encoding: false # Whether to use delta encoding.
latent_projector:
type: fc # Options: 'fc', 'rnn'
input_size: 50 # Input size for the Latent Projector (length of snippets).
latent_size: 8 # Size of the latent representation before message passing.
layer_shapes: [16, 32] # List of layer sizes for the latent projector (if type is 'fc').
activations: ['relu', 'relu'] # Activation functions for the latent projector layers (if type is 'fc').
rnn_hidden_size: 32 # Hidden size for the RNN projector (if type is 'rnn').
rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
middle_out:
output_size: 16 # Size of the latent representation after message passing.
num_peers: 3 # Number of most correlated peers to consider.
predictor:
type: lstm # Options: 'lstm', 'fixed_input_nn'
input_size: 1 # Input size for the LSTM predictor.
hidden_size: 8 # 16 # Hidden size for the LSTM or Fixed Input NN predictor.
num_layers: 2 # Number of layers for the LSTM predictor.
fixed_input_size: 10 # Input size for the Fixed Input NN predictor. Only used if type is 'fixed_input_nn'.
layer_shapes: [32, 16] # List of layer sizes for the predictor.
activations: ['relu', 'relu'] # Activation functions for the predictor layers.
training:
epochs: 128 # Number of training epochs.
@ -59,6 +66,9 @@ training:
save_path: models # Directory to save the best model and encoder.
num_points: 1000 # Number of data points to visualize
evaluation:
full_compression: false # Perform full compression during evaluation
bitstream_encoding:
type: identity # Options: 'identity', 'arithmetic', 'bzip2'
@ -66,14 +76,7 @@ data:
url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
directory: data # Directory to extract and store the dataset.
split_ratio: 0.8 # Ratio to split the data into train and test sets.
cut_length: null # Optional length to cut sequences to (null = use full sequences).
profiler:
enable: false
ablative:
training:
learning_rate: [0.01, 0.0001, 0.00001]
batch_size: [4, 16]
predictor:
hidden_size: [4, 16]
num_layers: [1, 3]
enable: false


@ -14,33 +14,33 @@ def download_and_extract_data(url, data_dir):
os.remove(zip_path)
def load_wav(file_path):
"""Load WAV file and return sample rate and data."""
sample_rate, data = wavfile.read(file_path)
return sample_rate, data
def load_all_wavs(data_dir):
"""Load all WAV files in the given directory."""
def load_all_wavs(data_dir, cut_length=None):
wav_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
all_data = []
for file_path in wav_files:
_, data = load_wav(file_path)
if cut_length:
data = data[:cut_length]
all_data.append(data)
return all_data
def save_wav(file_path, sample_rate, data):
"""Save data to a WAV file."""
wavfile.write(file_path, sample_rate, np.asarray(data, dtype=np.float32))
def compute_correlation_matrix(data):
num_leads = len(data)
corr_matrix = np.zeros((num_leads, num_leads))
for i in range(num_leads):
for j in range(num_leads):
if i != j:
corr_matrix[i, j] = np.corrcoef(data[i], data[j])[0, 1]
return corr_matrix
def delta_encode(data):
"""Apply delta encoding to the data."""
deltas = [data[0]]
for i in range(1, len(data)):
deltas.append(data[i] - data[i - 1])
return np.array(deltas)
def delta_decode(deltas):
"""Decode delta encoded data."""
data = [deltas[0]]
for i in range(1, len(deltas)):
data.append(data[-1] + deltas[i])
return np.array(data)
def split_data_by_time(data, split_ratio=0.5):
train_data = []
test_data = []
for lead in data:
split_idx = int(len(lead) * split_ratio)
train_data.append(lead[:split_idx])
test_data.append(lead[split_idx:])
return train_data, test_data

255
main.py

@ -1,77 +1,71 @@
from slate import Slate, Slate_Runner
import os
import torch
import torch.nn as nn
import numpy as np
import random
from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution
from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
import wandb
from pycallgraph import PyCallGraph
from pycallgraph.output import GraphvizOutput
import slate
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from data_processing import download_and_extract_data, load_all_wavs, delta_encode
from model import LSTMPredictor, FixedInputNNPredictor
from train import train_model
from bitstream import ArithmeticEncoder, IdentityEncoder, Bzip2Encoder
class SpikeRunner:
def __init__(self, config):
self.config = config
self.name = slate.consume(config, 'name', default='Test')
class SpikeRunner(Slate_Runner):
def setup(self, name):
self.name = name
slate, config = self.slate, self.config
# Consume config sections
preprocessing_config = slate.consume(config, 'preprocessing', expand=True)
predictor_config = slate.consume(config, 'predictor', expand=True)
training_config = slate.consume(config, 'training', expand=True)
bitstream_config = slate.consume(config, 'bitstream_encoding', expand=True)
data_config = slate.consume(config, 'data', expand=True)
# Data setup
data_url = slate.consume(data_config, 'url')
data_dir = slate.consume(data_config, 'directory')
cut_length = slate.consume(data_config, 'cut_length', None)
download_and_extract_data(data_url, data_dir)
all_data = load_all_wavs(data_dir)
all_data = load_all_wavs(data_dir, cut_length)
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
# Compute correlation matrix
self.correlation_matrix = compute_correlation_matrix(self.train_data)
# Model setup
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
if latent_projector_type == 'fc':
self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device)
# Training parameters
self.epochs = slate.consume(training_config, 'epochs')
self.batch_size = slate.consume(training_config, 'batch_size')
self.learning_rate = slate.consume(training_config, 'learning_rate')
self.use_delta_encoding = slate.consume(preprocessing_config, 'use_delta_encoding')
self.eval_freq = slate.consume(training_config, 'eval_freq')
self.save_path = slate.consume(training_config, 'save_path', 'models')
self.save_path = slate.consume(training_config, 'save_path')
if self.use_delta_encoding:
all_data = [delta_encode(d) for d in all_data]
# Evaluation parameter
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
# Split data into train and test sets
split_ratio = slate.consume(data_config, 'split_ratio', 0.8)
split_idx = int(len(all_data) * split_ratio)
self.train_data = all_data[:split_idx]
self.test_data = all_data[split_idx:]
# Model setup
self.model = self.get_model(predictor_config)
self.encoder = self.get_encoder(bitstream_config)
# Bitstream encoding
bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
if bitstream_type == 'identity':
self.encoder = IdentityEncoder()
elif bitstream_type == 'arithmetic':
self.encoder = ArithmeticEncoder()
elif bitstream_type == 'bzip2':
self.encoder = Bzip2Encoder()
def get_model(self, config):
model_type = slate.consume(config, 'type')
if model_type == 'lstm':
return LSTMPredictor(
input_size=slate.consume(config, 'input_size'),
hidden_size=slate.consume(config, 'hidden_size'),
num_layers=slate.consume(config, 'num_layers')
)
elif model_type == 'fixed_input_nn':
return FixedInputNNPredictor(
input_size=slate.consume(config, 'fixed_input_size'),
hidden_size=slate.consume(config, 'hidden_size')
)
else:
raise ValueError(f"Unknown model type: {model_type}")
def get_encoder(self, config):
encoder_type = slate.consume(config, 'type')
if encoder_type == 'arithmetic':
return ArithmeticEncoder()
elif encoder_type == 'identity':
return IdentityEncoder()
elif encoder_type == 'bzip2':
return Bzip2Encoder()
else:
raise ValueError(f"Unknown encoder type: {encoder_type}")
# Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
self.criterion = torch.nn.MSELoss()
def run(self, run, forceNoProfile=False):
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
@ -80,12 +74,147 @@ class SpikeRunner(Slate_Runner):
self.run(run, forceNoProfile=True)
print('{PROFILER DONE}')
return
self.train_model()
train_model(
self.model, self.train_data, self.test_data,
self.epochs, self.batch_size, self.learning_rate,
self.use_delta_encoding, self.encoder, self.eval_freq, self.save_path
)
def train_model(self):
max_length = max([len(seq) for seq in self.train_data])
print(f"Max sequence length: {max_length}")
best_test_score = float('inf')
for epoch in range(self.epochs):
total_loss = 0
random.shuffle(self.train_data)
for i in range(0, len(self.train_data[0]) - self.input_size, self.input_size):
batch_data = np.array([lead[i:i+self.input_size] for lead in self.train_data])
inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(device)
batch_loss = 0
for lead_idx in range(len(inputs)):
lead_data = inputs[lead_idx]
latents = self.projector(lead_data)
for t in range(latents.shape[0]):
my_latent = latents[t]
peer_latents = []
peer_correlations = []
for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
peer_latent = latents[t]
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
peer_latents.append(peer_latent)
peer_correlations.append(peer_correlation)
peer_latents = torch.stack(peer_latents).to(device)
peer_correlations = torch.stack(peer_correlations).to(device)
new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
prediction = self.predictor(new_latent)
target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
loss = self.criterion(prediction, target)
batch_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += batch_loss
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
print(f'Epoch {epoch+1}/{self.epochs}, Loss: {total_loss}')
if (epoch + 1) % self.eval_freq == 0:
test_loss = self.evaluate_model(epoch)
if test_loss < best_test_score:
best_test_score = test_loss
self.save_models(epoch)
def evaluate_model(self, epoch):
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()
total_loss = 0
all_true = []
all_predicted = []
all_deltas = []
compression_ratios = []
exact_matches = 0
total_sequences = 0
with torch.no_grad():
for lead_idx in range(len(self.test_data)):
lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device)
latents = self.projector(lead_data)
true_data = []
predicted_data = []
delta_data = []
for t in range(latents.shape[0]):
my_latent = latents[t]
peer_latents = []
peer_correlations = []
for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
peer_latent = latents[t]
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
peer_latents.append(peer_latent)
peer_correlations.append(peer_correlation)
peer_latents = torch.stack(peer_latents).to(device)
peer_correlations = torch.stack(peer_correlations).to(device)
new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
prediction = self.predictor(new_latent)
target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
loss = self.criterion(prediction, target)
total_loss += loss.item()
true_data.append(target.cpu().numpy())
predicted_data.append(prediction.cpu().numpy())
delta_data.append((target - prediction).cpu().numpy())
all_true.append(true_data)
all_predicted.append(predicted_data)
all_deltas.append(delta_data)
if self.full_compression:
self.encoder.build_model(latents.cpu().numpy())
compressed_data = self.encoder.encode(latents.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(latents))
compression_ratio = len(latents) / len(compressed_data)
compression_ratios.append(compression_ratio)
# Check if decompressed data matches the original data
if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5):
exact_matches += 1
total_sequences += 1
visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch)
avg_loss = total_loss / len(self.test_data)
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss}, step=epoch)
delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
if self.full_compression:
avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
exact_match_percentage = (exact_matches / total_sequences) * 100
print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
return avg_loss
def save_models(self, epoch):
torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
print(f"New high score! Models saved at epoch {epoch+1}.")
if __name__ == '__main__':
slate = Slate({'spikey': SpikeRunner})

110
model.py

@ -1,110 +0,0 @@
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
class BaseModel(nn.Module):
def __init__(self):
super(BaseModel, self).__init__()
@abstractmethod
def forward(self, x):
pass
@abstractmethod
def encode(self, data):
pass
@abstractmethod
def decode(self, encoded_data):
pass
class LSTMPredictor(BaseModel):
def __init__(self, input_size, hidden_size, num_layers):
super(LSTMPredictor, self).__init__()
self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, 1)
self.hidden_size = hidden_size
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def forward(self, x):
h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size).to(self.device)
out, _ = self.rnn(x, (h0, c0))
out = self.fc(out)
return out
def encode(self, data):
self.eval()
encoded_data = []
context_size = self.hidden_size # Define an appropriate context size
with torch.no_grad():
for i in range(len(data) - 1):
context = torch.tensor(data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
if context.size(1) == 0: # Handle empty context
continue
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
delta = data[i] - prediction
encoded_data.append(delta)
return encoded_data
def decode(self, encoded_data):
self.eval()
decoded_data = []
context_size = self.hidden_size # Define an appropriate context size
with torch.no_grad():
for i in range(len(encoded_data)):
context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1, 1).to(self.device)
if context.size(1) == 0: # Handle empty context
continue
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
decoded_data.append(prediction + encoded_data[i])
return decoded_data
class FixedInputNNPredictor(BaseModel):
def __init__(self, input_size, hidden_size):
super(FixedInputNNPredictor, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, 1)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def encode(self, data):
self.eval()
encoded_data = []
context_size = self.fc1.in_features # Define an appropriate context size
with torch.no_grad():
for i in range(len(data) - context_size):
context = torch.tensor(data[i:i + context_size]).reshape(1, -1).to(self.device)
if context.size(1) == 0: # Handle empty context
continue
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
delta = data[i + context_size] - prediction
encoded_data.append(delta)
return encoded_data
def decode(self, encoded_data):
self.eval()
decoded_data = []
context_size = self.fc1.in_features # Define an appropriate context size
with torch.no_grad():
for i in range(len(encoded_data)):
context = torch.tensor(decoded_data[max(0, i - context_size):i]).reshape(1, -1).to(self.device)
if context.size(1) == 0: # Handle empty context
continue
prediction = self.forward(context).squeeze(0).cpu().numpy()[0]
decoded_data.append(prediction + encoded_data[i])
return decoded_data

75
models.py Normal file

@ -0,0 +1,75 @@
import torch
import torch.nn as nn
def get_activation(name):
activations = {
'ReLU': nn.ReLU,
'Sigmoid': nn.Sigmoid,
'Tanh': nn.Tanh,
'LeakyReLU': nn.LeakyReLU,
'ELU': nn.ELU,
'None': nn.Identity
}
    # Case-insensitive lookup so config values like 'relu' match keys like 'ReLU'.
    return {k.lower(): v for k, v in activations.items()}[name.lower()]()
class LatentProjector(nn.Module):
def __init__(self, input_size, latent_size, layer_shapes, activations):
super(LatentProjector, self).__init__()
layers = []
in_features = input_size
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, latent_size))
self.fc = nn.Sequential(*layers)
self.latent_size = latent_size
def forward(self, x):
return self.fc(x)
class LatentRNNProjector(nn.Module):
def __init__(self, input_size, rnn_hidden_size, rnn_num_layers, latent_size):
super(LatentRNNProjector, self).__init__()
self.rnn = nn.LSTM(input_size, rnn_hidden_size, rnn_num_layers, batch_first=True)
self.fc = nn.Linear(rnn_hidden_size, latent_size)
self.latent_size = latent_size
def forward(self, x):
out, _ = self.rnn(x)
latent = self.fc(out)
return latent
class MiddleOut(nn.Module):
def __init__(self, latent_size, output_size, num_peers):
super(MiddleOut, self).__init__()
self.num_peers = num_peers
self.fc = nn.Linear(latent_size * 2 + 1, output_size)
def forward(self, my_latent, peer_latents, peer_correlations):
new_latents = []
for peer_latent, correlation in zip(peer_latents, peer_correlations):
combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1)
new_latent = self.fc(combined_input)
new_latents.append(new_latent)
new_latents = torch.stack(new_latents)
averaged_latent = torch.mean(new_latents, dim=0)
return averaged_latent
class Predictor(nn.Module):
def __init__(self, output_size, layer_shapes, activations):
super(Predictor, self).__init__()
layers = []
in_features = output_size
for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None':
layers.append(get_activation(activations[i]))
in_features = out_features
layers.append(nn.Linear(in_features, 1))
self.fc = nn.Sequential(*layers)
def forward(self, latent):
return self.fc(latent)

121
train.py

@ -1,121 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
import random
import os
import pickle
from data_processing import delta_encode, delta_decode, save_wav
from utils import visualize_prediction, plot_delta_distribution
from bitstream import ArithmeticEncoder
def pad_sequence(sequence, max_length):
padded_seq = np.zeros((max_length, *sequence.shape[1:]))
padded_seq[:sequence.shape[0], ...] = sequence
return padded_seq
def evaluate_model(model, data, use_delta_encoding, encoder, sample_rate=19531, epoch=0):
compression_ratios = []
identical_count = 0
all_deltas = []
for i, file_data in enumerate(data):
file_data = torch.tensor(file_data, dtype=torch.float32).unsqueeze(1).to(model.device)
encoded_data = model.encode(file_data.squeeze(1).cpu().numpy())
encoder.build_model(encoded_data)
compressed_data = encoder.encode(encoded_data)
decompressed_data = encoder.decode(compressed_data, len(encoded_data))
if use_delta_encoding:
decompressed_data = delta_decode(decompressed_data)
# Ensure the lengths match
min_length = min(len(file_data), len(decompressed_data))
file_data = file_data[:min_length]
decompressed_data = decompressed_data[:min_length]
identical = np.allclose(file_data.cpu().numpy(), decompressed_data, atol=1e-5)
if identical:
identical_count += 1
compression_ratio = len(file_data) / len(compressed_data)
compression_ratios.append(compression_ratio)
predicted_data = model(torch.tensor(encoded_data, dtype=torch.float32).unsqueeze(1).to(model.device)).squeeze(1).detach().cpu().numpy()
if use_delta_encoding:
predicted_data = delta_decode(predicted_data)
# Ensure predicted_data is a flat list of floats
predicted_data = predicted_data[:min_length]
delta_data = [file_data[i].item() - predicted_data[i] for i in range(min_length)]
all_deltas.extend(delta_data)
if i == (epoch % len(data)):
visualize_prediction(file_data.cpu().numpy(), predicted_data, delta_data, sample_rate, epoch=epoch)
identical_percentage = (identical_count / len(data)) * 100
delta_plot_path = plot_delta_distribution(all_deltas, epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
return compression_ratios, identical_percentage
def train_model(model, train_data, test_data, epochs, batch_size, learning_rate, use_delta_encoding, encoder, eval_freq, save_path):
wandb.init(project="wav-compression")
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_test_score = float('inf')
model.to(model.device)
max_length = max([len(seq) for seq in train_data])
print(f"Max sequence length: {max_length}")
for epoch in range(epochs):
total_loss = 0
random.shuffle(train_data)
for i in range(0, len(train_data) - batch_size, batch_size):
batch_data = [pad_sequence(np.array(train_data[j]), max_length) for j in range(i, i+batch_size)]
batch_data = np.array(batch_data)
inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
targets = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(model.device)
outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss}')
if (epoch + 1) % eval_freq == 0:
train_compression_ratios, train_identical_percentage = evaluate_model(model, train_data, use_delta_encoding, encoder, epoch=epoch)
test_compression_ratios, test_identical_percentage = evaluate_model(model, test_data, use_delta_encoding, encoder, epoch=epoch)
wandb.log({
"train_compression_ratio_mean": np.mean(train_compression_ratios),
"train_compression_ratio_std": np.std(train_compression_ratios),
"train_compression_ratio_min": np.min(train_compression_ratios),
"train_compression_ratio_max": np.max(train_compression_ratios),
"test_compression_ratio_mean": np.mean(test_compression_ratios),
"test_compression_ratio_std": np.std(test_compression_ratios),
"test_compression_ratio_min": np.min(test_compression_ratios),
"test_compression_ratio_max": np.max(test_compression_ratios),
"train_identical_percentage": train_identical_percentage,
"test_identical_percentage": test_identical_percentage,
}, step=epoch)
print(f'Epoch {epoch+1}/{epochs}, Train Compression Ratio: Mean={np.mean(train_compression_ratios)}, Std={np.std(train_compression_ratios)}, Min={np.min(train_compression_ratios)}, Max={np.max(train_compression_ratios)}, Identical={train_identical_percentage}%')
print(f'Epoch {epoch+1}/{epochs}, Test Compression Ratio: Mean={np.mean(test_compression_ratios)}, Std={np.std(test_compression_ratios)}, Min={np.min(test_compression_ratios)}, Max={np.max(test_compression_ratios)}, Identical={test_identical_percentage}%')
test_score = np.mean(test_compression_ratios)
if test_score < best_test_score:
best_test_score = test_score
model_path = os.path.join(save_path, f"best_model_epoch_{epoch+1}.pt")
encoder_path = os.path.join(save_path, f"best_encoder_epoch_{epoch+1}.pkl")
torch.save(model.state_dict(), model_path)
with open(encoder_path, 'wb') as f:
pickle.dump(encoder, f)
print(f'New highscore on test data! Model and encoder saved to {model_path} and {encoder_path}')