Fixed bugs

Dominik Moritz Roth 2024-05-25 20:27:54 +02:00
parent 97de63e946
commit bc783b9888
4 changed files with 179 additions and 99 deletions


@@ -45,22 +45,23 @@ latent_projector:
type: rnn # Options: 'fc', 'rnn'
input_size: 19531 # ~1 s of samples; input size for the Latent Projector (length of snippets).
latent_size: 8 # Size of the latent representation before message passing.
layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc').
activations: ['relu', 'relu'] # Activation functions for the latent projector layers (if type is 'fc').
rnn_hidden_size: 16 # Hidden size for the RNN projector (if type is 'rnn').
rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
#layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc').
#activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc').
rnn_hidden_size: 12 # Hidden size for the RNN projector (if type is 'rnn').
rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn').
middle_out:
output_size: 8 # Size of the latent representation after message passing.
num_peers: 8 # Number of most correlated peers to consider.
num_peers: 3 # Number of most correlated peers to consider.
predictor:
layer_shapes: [8, 4] # List of layer sizes for the predictor.
activations: ['relu', 'none'] # Activation functions for the predictor layers.
activations: ['ReLU', 'None'] # Activation functions for the predictor layers.
training:
epochs: 128 # Number of training epochs.
batch_size: 8 # Batch size for training.
batch_size: 64 # Batch size for training.
num_batches: 16 # Batches per epoch
learning_rate: 0.001 # Learning rate for the optimizer.
eval_freq: 8 # Frequency of evaluation during training (in epochs).
save_path: models # Directory to save the best model and encoder.
@@ -76,7 +77,7 @@ data:
url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
directory: data # Directory to extract and store the dataset.
split_ratio: 0.8 # Ratio to split the data into train and test sets.
cut_length: None # Optional length to cut sequences to.
cut_length: null # Optional length to cut sequences to.
profiler:
enable: false
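The change from cut_length: None to cut_length: null matters because YAML has no None literal: a bare None is parsed as the string 'None', while null (or ~) maps to Python's None. A minimal sketch of the difference, assuming the config is read with PyYAML:

import yaml

cfg_old = yaml.safe_load("cut_length: None")   # {'cut_length': 'None'}  (a string, not None)
cfg_new = yaml.safe_load("cut_length: null")   # {'cut_length': None}

# Slicing with the string 'None' would raise a TypeError downstream,
# while a real None cleanly disables the optional cut.
assert cfg_new["cut_length"] is None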


@@ -21,18 +21,20 @@ def load_all_wavs(data_dir, cut_length=None):
all_data = []
for file_path in wav_files:
_, data = load_wav(file_path)
if cut_length:
if cut_length is not None:
print(cut_length)
data = data[:cut_length]
all_data.append(data)
return all_data
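The switch from "if cut_length:" to "if cut_length is not None:" is the usual truthiness fix: a falsy but valid value such as 0 would previously have skipped the cut entirely. A small illustration with made-up values:

def maybe_cut(data, cut_length=None):
    # Explicit None check: cut_length == 0 means "cut to length 0",
    # not "no cut requested".
    if cut_length is not None:
        return data[:cut_length]
    return data

assert maybe_cut([1, 2, 3], 0) == []        # a plain truthiness check would return [1, 2, 3]
assert maybe_cut([1, 2, 3]) == [1, 2, 3]    # None still means "leave untouched"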
def compute_correlation_matrix(data):
num_leads = len(data)
corr_matrix = np.zeros((num_leads, num_leads))
for i in range(num_leads):
for j in range(num_leads):
if i != j:
corr_matrix[i, j] = np.corrcoef(data[i], data[j])[0, 1]
min_length = min(len(d) for d in data)
# Trim all leads to the minimum length
trimmed_data = [d[:min_length] for d in data]
corr_matrix = np.corrcoef(trimmed_data)
return corr_matrix
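np.corrcoef already returns the full lead-by-lead correlation matrix when given a 2-D array with one row per lead, which is what the rewrite relies on; the only requirement is equal-length rows, hence the trim to min_length. A quick self-contained check on made-up data that the single vectorized call matches per-pair calls (the diagonal is now 1 instead of the 0 the old loop left in place):

import numpy as np

rng = np.random.default_rng(0)
data = [rng.normal(size=100), rng.normal(size=120), rng.normal(size=90)]

min_length = min(len(d) for d in data)
trimmed = np.stack([d[:min_length] for d in data])   # (num_leads, min_length)
fast = np.corrcoef(trimmed)                          # (num_leads, num_leads)

slow = np.zeros((3, 3))
for i in range(3):
    for j in range(3):
        slow[i, j] = np.corrcoef(trimmed[i], trimmed[j])[0, 1]

assert np.allclose(fast, slow)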
def split_data_by_time(data, split_ratio=0.5):

main.py

@@ -2,21 +2,24 @@ import os
import torch
import torch.nn as nn
import numpy as np
import random
from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution
import random, math
from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix
from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
import wandb
from pycallgraph import PyCallGraph
from pycallgraph.output import GraphvizOutput
import slate
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
from slate import Slate, Slate_Runner
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class SpikeRunner:
def __init__(self, config):
self.config = config
self.name = slate.consume(config, 'name', default='Test')
class SpikeRunner(Slate_Runner):
def setup(self, name):
print("Setup SpikeRunner")
self.name = name
slate, config = self.slate, self.config
training_config = slate.consume(config, 'training', expand=True)
data_config = slate.consume(config, 'data', expand=True)
@@ -30,22 +30,36 @@ class SpikeRunner:
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
# Compute correlation matrix
print("Computing correlation matrix")
self.correlation_matrix = compute_correlation_matrix(self.train_data)
# Number of peers for message passing
self.num_peers = slate.consume(config, 'middle_out.num_peers')
# Precompute sorted indices for the top num_peers correlated leads
print("Precomputing sorted peer indices")
self.sorted_peer_indices = np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers]
# Model setup
print("Setting up models")
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size')
output_size = slate.consume(config, 'middle_out.output_size')
if latent_projector_type == 'fc':
self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device)
self.middle_out = MiddleOut(latent_size=latent_size, output_size=output_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(output_size=output_size, **slate.consume(config, 'predictor', expand=True)).to(device)
# Training parameters
self.input_size = input_size
self.epochs = slate.consume(training_config, 'epochs')
self.batch_size = slate.consume(training_config, 'batch_size')
self.num_batches = slate.consume(training_config, 'num_batches')
self.learning_rate = slate.consume(training_config, 'learning_rate')
self.eval_freq = slate.consume(training_config, 'eval_freq')
self.save_path = slate.consume(training_config, 'save_path')
@@ -65,6 +82,7 @@ class SpikeRunner:
# Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
self.criterion = torch.nn.MSELoss()
print("SpikeRunner initialization complete")
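The peer-selection step used earlier in setup, np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers], takes per row the column indices of the largest correlations. A toy check of that indexing (hypothetical 3x3 matrix, not project data):

import numpy as np

corr = np.array([[1.0, 0.2, 0.9],
                 [0.2, 1.0, 0.4],
                 [0.9, 0.4, 1.0]])

top = np.argsort(-corr, axis=1)[:, :2]   # two most correlated peers per lead
print(top)   # row 0 -> [0, 2]: note the lead itself is kept, since the diagonal is 1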
def run(self, run, forceNoProfile=False):
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
@@ -77,58 +95,88 @@ class SpikeRunner:
self.train_model()
def train_model(self):
max_length = max([len(seq) for seq in self.train_data])
print(f"Max sequence length: {max_length}")
min_length = min([len(seq) for seq in self.train_data])
best_test_score = float('inf')
for epoch in range(self.epochs):
total_loss = 0
random.shuffle(self.train_data)
for i in range(0, len(self.train_data[0]) - self.input_size, self.input_size):
batch_data = np.array([lead[i:i+self.input_size] for lead in self.train_data])
inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(device)
batch_loss = 0
for lead_idx in range(len(inputs)):
lead_data = inputs[lead_idx]
latents = self.projector(lead_data)
for t in range(latents.shape[0]):
my_latent = latents[t]
peer_latents = []
peer_correlations = []
for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
peer_latent = latents[t]
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
peer_latents.append(peer_latent)
peer_correlations.append(peer_correlation)
peer_latents = torch.stack(peer_latents).to(device)
peer_correlations = torch.stack(peer_correlations).to(device)
new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
prediction = self.predictor(new_latent)
target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
loss = self.criterion(prediction, target)
batch_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += batch_loss
for batch_num in range(self.num_batches):
# Create indices for training data and shuffle them
indices = list(range(len(self.train_data)))
random.shuffle(indices)
stacked_segments = []
peer_correlations = []
targets = []
for idx in indices[:self.batch_size]:
lead_data = self.train_data[idx][:min_length]
# Slide a window over the data with overlap
stride = max(1, self.input_size // 8) # Ensuring stride is at least 1
for i in range(0, len(lead_data) - self.input_size-1, stride):
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
# Collect the segments for the current lead and its peers
peer_segments = []
for peer_idx in self.sorted_peer_indices[idx]:
peer_segment = self.train_data[peer_idx][i:i + self.input_size][:min_length]
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device) # Shape: (num_peers)
peer_correlations.append(peer_correlation)
# Stack the segments to form the batch
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1]
targets.append(target)
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments))
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
# Pass through MiddleOut
new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_correlations))
prediction = self.predictor(new_latent)
# Calculate loss and backpropagate
loss = self.criterion(prediction, torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device))
total_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
print(f'Epoch {epoch+1}/{self.epochs}, Loss: {total_loss}')
if (epoch + 1) % self.eval_freq == 0:
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
print(f'Starting evaluation for epoch {epoch + 1}')
test_loss = self.evaluate_model(epoch)
if test_loss < best_test_score:
best_test_score = test_loss
self.save_models(epoch)
print(f'Evaluation complete for epoch {epoch + 1}')
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if (epoch + 1) % self.eval_freq == 0:
print(f'Starting evaluation for epoch {epoch + 1}')
test_loss = self.evaluate_model(epoch)
if test_loss < best_test_score:
best_test_score = test_loss
self.save_models(epoch)
print(f'Evaluation complete for epoch {epoch + 1}')
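The overlapping-window layout used by the new training loop is easier to see in isolation. A toy sketch with made-up lengths (the real snippets are input_size = 19531 samples):

import numpy as np

lead = np.arange(100)                # toy lead signal
input_size = 16
stride = max(1, input_size // 8)     # same overlap rule as train_model

windows, targets = [], []
for i in range(0, len(lead) - input_size - 1, stride):
    windows.append(lead[i:i + input_size])
    targets.append(lead[i + input_size + 1])   # same target index as the loop above

print(np.stack(windows).shape)       # (42, 16): heavily overlapping segments
print(targets[:3])                   # [17, 19, 21] on this ramp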
def evaluate_model(self, epoch):
print('Evaluating model...')
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()
@@ -143,59 +191,82 @@ class SpikeRunner:
with torch.no_grad():
for lead_idx in range(len(self.test_data)):
lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device)
latents = self.projector(lead_data)
lead_data = self.test_data[lead_idx]
true_data = []
predicted_data = []
delta_data = []
targets = []
for t in range(latents.shape[0]):
my_latent = latents[t]
min_length = min([len(seq) for seq in self.test_data])
peer_latents = []
peer_correlations = []
for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
peer_latent = latents[t]
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
peer_latents.append(peer_latent)
peer_correlations.append(peer_correlation)
# Initialize lists to store segments and peer correlations
stacked_segments = []
peer_correlations = []
peer_latents = torch.stack(peer_latents).to(device)
peer_correlations = torch.stack(peer_correlations).to(device)
new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
prediction = self.predictor(new_latent)
target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
loss = self.criterion(prediction, target)
total_loss += loss.item()
# Collect peer segments and correlations
peer_segments = []
for peer_idx in self.sorted_peer_indices[lead_idx]:
peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
peer_correlations.append(peer_correlation)
true_data.append(target.cpu().numpy())
predicted_data.append(prediction.cpu().numpy())
delta_data.append((target - prediction).cpu().numpy())
# Stack segments to form the batch
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1]
targets.append(target)
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments))
my_latents = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
# Pass through MiddleOut
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_correlations))
# Predict using the predictor
predictions = self.predictor(new_latents)
# Compute loss and store true and predicted data
for i, segment in enumerate(stacked_segments):
for t in range(self.input_size):
target = torch.tensor(targets[i])
true_data.append(target.cpu().numpy())
predicted_data.append(predictions[i, t, :].cpu().numpy())
delta_data.append((target - predictions[i, t, :]).cpu().numpy())
loss = self.criterion(predictions[i, t, :], target)
total_loss += loss.item()
# Append true and predicted data for this lead sequence
all_true.append(true_data)
all_predicted.append(predicted_data)
all_deltas.append(delta_data)
if self.full_compression:
self.encoder.build_model(latents.cpu().numpy())
compressed_data = self.encoder.encode(latents.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(latents))
compression_ratio = len(latents) / len(compressed_data)
# Bitstream encoding
self.encoder.build_model(my_latents.cpu().numpy())
compressed_data = self.encoder.encode(my_latents.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
compression_ratio = len(my_latents) / len(compressed_data)
compression_ratios.append(compression_ratio)
# Check if decompressed data matches the original data
if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5):
if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
exact_matches += 1
total_sequences += 1
visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch)
avg_loss = total_loss / len(self.test_data)
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss}, step=epoch)
# Visualize delta distribution
delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
@@ -207,14 +278,19 @@ class SpikeRunner:
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
print('Evaluation done for this epoch.')
return avg_loss
def save_models(self, epoch):
print('Saving models...')
torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
print(f"New high score! Models saved at epoch {epoch+1}.")
if __name__ == '__main__':
print('Initializing...')
slate = Slate({'spikey': SpikeRunner})
slate.from_args()
print('Done.')


@@ -49,8 +49,9 @@ class MiddleOut(nn.Module):
def forward(self, my_latent, peer_latents, peer_correlations):
new_latents = []
for peer_latent, correlation in zip(peer_latents, peer_correlations):
combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1)
for p in range(peer_latents.shape[-2]):
peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
new_latent = self.fc(combined_input)
new_latents.append(new_latent)
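Shapes for the reworked batched forward pass, as a standalone sketch; the Linear layer and the final mean over peers are stand-ins, and only the concatenation layout (my_latent, peer_latent, correlation) is taken from the code above:

import torch
import torch.nn as nn

latent_size, output_size, num_peers, batch = 8, 8, 3, 4
fc = nn.Linear(2 * latent_size + 1, output_size)

my_latent = torch.randn(batch, latent_size)                 # (B, latent)
peer_latents = torch.randn(batch, num_peers, latent_size)   # (B, P, latent)
peer_correlations = torch.randn(batch, num_peers)           # (B, P)

new_latents = []
for p in range(peer_latents.shape[-2]):
    peer_latent = peer_latents[:, p, :]                     # (B, latent)
    correlation = peer_correlations[:, p].unsqueeze(1)      # (B, 1)
    new_latents.append(fc(torch.cat((my_latent, peer_latent, correlation), dim=-1)))

out = torch.stack(new_latents).mean(dim=0)                  # (B, output_size); the reduction is an assumption
print(out.shape)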