Fixed bugs

This commit is contained in:
Dominik Moritz Roth 2024-05-25 20:27:54 +02:00
parent 97de63e946
commit bc783b9888
4 changed files with 179 additions and 99 deletions

View File

@ -45,22 +45,23 @@ latent_projector:
type: rnn # Options: 'fc', 'rnn' type: rnn # Options: 'fc', 'rnn'
input_size: 19531 # =1s Input size for the Latent Projector (length of snippets). input_size: 19531 # =1s Input size for the Latent Projector (length of snippets).
latent_size: 8 # Size of the latent representation before message passing. latent_size: 8 # Size of the latent representation before message passing.
layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc'). #layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc').
activations: ['relu', 'relu'] # Activation functions for the latent projector layers (if type is 'fc'). #activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc').
rnn_hidden_size: 16 # Hidden size for the RNN projector (if type is 'rnn'). rnn_hidden_size: 12 # Hidden size for the RNN projector (if type is 'rnn').
rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn'). rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn').
middle_out: middle_out:
output_size: 8 # Size of the latent representation after message passing. output_size: 8 # Size of the latent representation after message passing.
num_peers: 8 # Number of most correlated peers to consider. num_peers: 3 # Number of most correlated peers to consider.
predictor: predictor:
layer_shapes: [8, 4] # List of layer sizes for the predictor. layer_shapes: [8, 4] # List of layer sizes for the predictor.
activations: ['relu', 'none'] # Activation functions for the predictor layers. activations: ['ReLU', 'None'] # Activation functions for the predictor layers.
training: training:
epochs: 128 # Number of training epochs. epochs: 128 # Number of training epochs.
batch_size: 8 # Batch size for training. batch_size: 64 # Batch size for training.
num_batches: 16 # Batches per epoch
learning_rate: 0.001 # Learning rate for the optimizer. learning_rate: 0.001 # Learning rate for the optimizer.
eval_freq: 8 # Frequency of evaluation during training (in epochs). eval_freq: 8 # Frequency of evaluation during training (in epochs).
save_path: models # Directory to save the best model and encoder. save_path: models # Directory to save the best model and encoder.
@ -76,7 +77,7 @@ data:
url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset. url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset.
directory: data # Directory to extract and store the dataset. directory: data # Directory to extract and store the dataset.
split_ratio: 0.8 # Ratio to split the data into train and test sets. split_ratio: 0.8 # Ratio to split the data into train and test sets.
cut_length: None # Optional length to cut sequences to. cut_length: null # Optional length to cut sequences to.
profiler: profiler:
enable: false enable: false

View File

@ -21,18 +21,20 @@ def load_all_wavs(data_dir, cut_length=None):
all_data = [] all_data = []
for file_path in wav_files: for file_path in wav_files:
_, data = load_wav(file_path) _, data = load_wav(file_path)
if cut_length: if cut_length is not None:
print(cut_length)
data = data[:cut_length] data = data[:cut_length]
all_data.append(data) all_data.append(data)
return all_data return all_data
def compute_correlation_matrix(data):
    """Return the pairwise Pearson correlation between leads.

    Leads may have different lengths, so every lead is first trimmed to the
    length of the shortest one.

    Args:
        data: Sequence of 1-D sample sequences, one per lead.

    Returns:
        A ``(num_leads, num_leads)`` float array. The diagonal is zero so
        that ranking a row by correlation never selects the lead itself as
        its own best peer. Entries that are undefined (for example a lead
        with constant signal) are reported as 0 instead of NaN.
    """
    num_leads = len(data)
    if num_leads == 0:
        return np.zeros((0, 0))

    # Trim all leads to the minimum length so they can be stacked row-wise.
    min_length = min(len(d) for d in data)
    trimmed_data = np.asarray([d[:min_length] for d in data], dtype=np.float64)

    # A constant lead has zero variance, which makes its correlations 0/0.
    with np.errstate(invalid='ignore', divide='ignore'):
        # np.corrcoef collapses to a scalar for a single lead; keep it 2-D.
        corr_matrix = np.atleast_2d(np.corrcoef(trimmed_data))

    corr_matrix = np.nan_to_num(corr_matrix, nan=0.0)
    # Self-correlation is always 1 and would dominate any "top peers" ranking.
    np.fill_diagonal(corr_matrix, 0.0)
    return corr_matrix
def split_data_by_time(data, split_ratio=0.5): def split_data_by_time(data, split_ratio=0.5):

224
main.py
View File

@ -2,21 +2,24 @@ import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import random import random, math
from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix
from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
import wandb import wandb
from pycallgraph import PyCallGraph from pycallgraph2 import PyCallGraph
from pycallgraph.output import GraphvizOutput from pycallgraph2.output import GraphvizOutput
import slate from slate import Slate, Slate_Runner
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class SpikeRunner: class SpikeRunner(Slate_Runner):
def __init__(self, config): def setup(self, name):
self.config = config print("Setup SpikeRunner")
self.name = slate.consume(config, 'name', default='Test')
self.name = name
slate, config = self.slate, self.config
training_config = slate.consume(config, 'training', expand=True) training_config = slate.consume(config, 'training', expand=True)
data_config = slate.consume(config, 'data', expand=True) data_config = slate.consume(config, 'data', expand=True)
@ -30,22 +33,36 @@ class SpikeRunner:
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio) self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
# Compute correlation matrix # Compute correlation matrix
print("Computing correlation matrix")
self.correlation_matrix = compute_correlation_matrix(self.train_data) self.correlation_matrix = compute_correlation_matrix(self.train_data)
# Number of peers for message passing
self.num_peers = slate.consume(config, 'middle_out.num_peers')
# Precompute sorted indices for the top num_peers correlated leads
print("Precomputing sorted peer indices")
self.sorted_peer_indices = np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers]
# Model setup # Model setup
print("Setting up models")
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc') latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size')
output_size = slate.consume(config, 'middle_out.output_size')
if latent_projector_type == 'fc': if latent_projector_type == 'fc':
self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn': elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device) self.middle_out = MiddleOut(latent_size=latent_size, output_size=output_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device) self.predictor = Predictor(output_size=output_size, **slate.consume(config, 'predictor', expand=True)).to(device)
# Training parameters # Training parameters
self.input_size = input_size
self.epochs = slate.consume(training_config, 'epochs') self.epochs = slate.consume(training_config, 'epochs')
self.batch_size = slate.consume(training_config, 'batch_size') self.batch_size = slate.consume(training_config, 'batch_size')
self.num_batches = slate.consume(training_config, 'num_batches')
self.learning_rate = slate.consume(training_config, 'learning_rate') self.learning_rate = slate.consume(training_config, 'learning_rate')
self.eval_freq = slate.consume(training_config, 'eval_freq') self.eval_freq = slate.consume(training_config, 'eval_freq')
self.save_path = slate.consume(training_config, 'save_path') self.save_path = slate.consume(training_config, 'save_path')
@ -65,6 +82,7 @@ class SpikeRunner:
# Optimizer # Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
self.criterion = torch.nn.MSELoss() self.criterion = torch.nn.MSELoss()
print("SpikeRunner initialization complete")
def run(self, run, forceNoProfile=False): def run(self, run, forceNoProfile=False):
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile: if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
@ -77,58 +95,88 @@ class SpikeRunner:
self.train_model() self.train_model()
def train_model(self):
    """Train the projector, middle-out and predictor modules.

    Every epoch runs ``self.num_batches`` optimisation steps. For each step
    up to ``self.batch_size`` leads are drawn at random, and a window of
    ``self.input_size`` samples is slid over each drawn lead. Each window is
    stacked together with the same time window of the lead's precomputed
    peers (``self.sorted_peer_indices``) and the model is asked to predict a
    single sample following the window.

    After each epoch the summed batch loss is logged to wandb. Every
    ``self.eval_freq`` epochs the model is evaluated and, if the evaluation
    loss improved, the models are saved. A non-positive ``eval_freq``
    (the config uses -1) disables evaluation.

    Raises:
        ValueError: If the shortest training lead is too short to hold one
            window plus its target sample.
    """
    min_length = min(len(seq) for seq in self.train_data)
    if min_length < self.input_size + 2:
        raise ValueError(
            f"Shortest training lead has {min_length} samples; at least "
            f"{self.input_size + 2} are required for one window plus its target."
        )

    # Slide a window over the data with overlap; stride is at least 1.
    stride = max(1, self.input_size // 8)
    window_starts = range(0, min_length - self.input_size - 1, stride)

    num_leads = len(self.train_data)
    leads_per_batch = min(self.batch_size, num_leads)
    modules = (self.projector, self.middle_out, self.predictor)

    best_test_score = float('inf')

    for epoch in range(self.epochs):
        # evaluate_model() puts the modules into eval mode; switch back
        # before taking any further optimisation steps.
        for module in modules:
            module.train()

        total_loss = 0
        for _ in range(self.num_batches):
            stacked_segments = []
            peer_correlations = []
            targets = []

            for idx in random.sample(range(num_leads), leads_per_batch):
                lead_data = self.train_data[idx][:min_length]
                peer_indices = self.sorted_peer_indices[idx]

                # Identical for every window of this lead: shape (num_peers,).
                peer_correlation = torch.tensor(
                    [self.correlation_matrix[idx, peer_idx] for peer_idx in peer_indices],
                    dtype=torch.float32,
                )

                for i in window_starts:
                    window = slice(i, i + self.input_size)

                    # Row 0 is the lead itself, rows 1.. are its peers.
                    segments = [torch.tensor(lead_data[window], dtype=torch.float32)]
                    segments.extend(
                        torch.tensor(self.train_data[peer_idx][window], dtype=torch.float32)
                        for peer_idx in peer_indices
                    )
                    stacked_segments.append(torch.stack(segments))
                    peer_correlations.append(peer_correlation)

                    # NOTE(review): the target is the sample at
                    # i + input_size + 1, i.e. it skips the sample directly
                    # after the window. Kept as-is; confirm this is intended.
                    targets.append(lead_data[i + self.input_size + 1])

            # Pass the batch through the projector. The batch is assembled on
            # the CPU and moved to the device in one transfer.
            # Input shape: (windows, 1 + num_peers, input_size).
            latents = self.projector(torch.stack(stacked_segments).to(device))
            my_latent = latents[:, 0, :]
            peer_latents = latents[:, 1:, :]

            # Pass through MiddleOut
            new_latent = self.middle_out(
                my_latent, peer_latents, torch.stack(peer_correlations).to(device)
            )
            prediction = self.predictor(new_latent)

            # Calculate loss and backpropagate
            target_tensor = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
            loss = self.criterion(prediction, target_tensor)
            total_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
        print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

        # Evaluate exactly once per due epoch. A bare "% eval_freq == 0"
        # check is always true for -1, so the guard must come first.
        if self.eval_freq > 0 and (epoch + 1) % self.eval_freq == 0:
            print(f'Starting evaluation for epoch {epoch + 1}')
            test_loss = self.evaluate_model(epoch)
            if test_loss < best_test_score:
                best_test_score = test_loss
                self.save_models(epoch)
            print(f'Evaluation complete for epoch {epoch + 1}')
def evaluate_model(self, epoch): def evaluate_model(self, epoch):
print('Evaluating model...')
self.projector.eval() self.projector.eval()
self.middle_out.eval() self.middle_out.eval()
self.predictor.eval() self.predictor.eval()
@ -143,59 +191,82 @@ class SpikeRunner:
with torch.no_grad(): with torch.no_grad():
for lead_idx in range(len(self.test_data)): for lead_idx in range(len(self.test_data)):
lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device) lead_data = self.test_data[lead_idx]
latents = self.projector(lead_data)
true_data = [] true_data = []
predicted_data = [] predicted_data = []
delta_data = [] delta_data = []
targets = []
for t in range(latents.shape[0]): min_length = min([len(seq) for seq in self.test_data])
my_latent = latents[t]
peer_latents = [] # Initialize lists to store segments and peer correlations
peer_correlations = [] stacked_segments = []
for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]: peer_correlations = []
peer_latent = latents[t]
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
peer_latents.append(peer_latent)
peer_correlations.append(peer_correlation)
peer_latents = torch.stack(peer_latents).to(device) for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
peer_correlations = torch.stack(peer_correlations).to(device) lead_segment = lead_data[i:i + self.input_size]
new_latent = self.middle_out(my_latent, peer_latents, peer_correlations) inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
prediction = self.predictor(new_latent)
target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
loss = self.criterion(prediction, target) # Collect peer segments and correlations
total_loss += loss.item() peer_segments = []
for peer_idx in self.sorted_peer_indices[lead_idx]:
peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
peer_correlations.append(peer_correlation)
true_data.append(target.cpu().numpy()) # Stack segments to form the batch
predicted_data.append(prediction.cpu().numpy()) stacked_segment = torch.stack([inputs] + peer_segments).to(device)
delta_data.append((target - prediction).cpu().numpy()) stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1]
targets.append(target)
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments))
my_latents = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
# Pass through MiddleOut
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_correlations))
# Predict using the predictor
predictions = self.predictor(new_latents)
# Compute loss and store true and predicted data
for i, segment in enumerate(stacked_segments):
for t in range(self.input_size):
target = torch.tensor(targets[i])
true_data.append(target.cpu().numpy())
predicted_data.append(predictions[i, t, :].cpu().numpy())
delta_data.append((target - predictions[i, t, :]).cpu().numpy())
loss = self.criterion(predictions[i, t, :], target)
total_loss += loss.item()
# Append true and predicted data for this lead sequence
all_true.append(true_data) all_true.append(true_data)
all_predicted.append(predicted_data) all_predicted.append(predicted_data)
all_deltas.append(delta_data) all_deltas.append(delta_data)
if self.full_compression: if self.full_compression:
self.encoder.build_model(latents.cpu().numpy()) # Bitstream encoding
compressed_data = self.encoder.encode(latents.cpu().numpy()) self.encoder.build_model(my_latents.cpu().numpy())
decompressed_data = self.encoder.decode(compressed_data, len(latents)) compressed_data = self.encoder.encode(my_latents.cpu().numpy())
compression_ratio = len(latents) / len(compressed_data) decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
compression_ratio = len(my_latents) / len(compressed_data)
compression_ratios.append(compression_ratio) compression_ratios.append(compression_ratio)
# Check if decompressed data matches the original data # Check if decompressed data matches the original data
if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5): if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
exact_matches += 1 exact_matches += 1
total_sequences += 1 total_sequences += 1
visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch)
avg_loss = total_loss / len(self.test_data) avg_loss = total_loss / len(self.test_data)
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss}, step=epoch) wandb.log({"evaluation_loss": avg_loss}, step=epoch)
# Visualize delta distribution
delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch) delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
@ -207,14 +278,19 @@ class SpikeRunner:
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
print('Evaluation done for this epoch.')
return avg_loss return avg_loss
def save_models(self, epoch):
    """Write the state dicts of all three modules to ``self.save_path``.

    Files are named ``best_<module>_epoch_<epoch + 1>.pt``. The target
    directory is created if it does not exist yet, so a fresh checkout
    does not fail on the first improved evaluation.

    Args:
        epoch: Zero-based epoch index; the file names use ``epoch + 1``.
    """
    print('Saving models...')
    os.makedirs(self.save_path, exist_ok=True)

    named_modules = (
        ('projector', self.projector),
        ('middle_out', self.middle_out),
        ('predictor', self.predictor),
    )
    for name, module in named_modules:
        target = os.path.join(self.save_path, f"best_{name}_epoch_{epoch+1}.pt")
        torch.save(module.state_dict(), target)

    print(f"New high score! Models saved at epoch {epoch+1}.")
if __name__ == '__main__':
    # Script entry point: register SpikeRunner under the name 'spikey' and
    # let Slate process the command-line arguments.
    print('Initializing...')
    runners = {'spikey': SpikeRunner}
    slate = Slate(runners)
    slate.from_args()
    print('Done.')

View File

@ -49,8 +49,9 @@ class MiddleOut(nn.Module):
def forward(self, my_latent, peer_latents, peer_correlations): def forward(self, my_latent, peer_latents, peer_correlations):
new_latents = [] new_latents = []
for peer_latent, correlation in zip(peer_latents, peer_correlations): for p in range(peer_latents.shape[-2]):
combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1) peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
new_latent = self.fc(combined_input) new_latent = self.fc(combined_input)
new_latents.append(new_latent) new_latents.append(new_latent)