diff --git a/config.yaml b/config.yaml index 9a12314..8568d5d 100644 --- a/config.yaml +++ b/config.yaml @@ -45,22 +45,23 @@ latent_projector: type: rnn # Options: 'fc', 'rnn' input_size: 19531 # =1s Input size for the Latent Projector (length of snippets). latent_size: 8 # Size of the latent representation before message passing. - layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc'). - activations: ['relu', 'relu'] # Activation functions for the latent projector layers (if type is 'fc'). - rnn_hidden_size: 16 # Hidden size for the RNN projector (if type is 'rnn'). - rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn'). + #layer_shapes: [256, 32] # List of layer sizes for the latent projector (if type is 'fc'). + #activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers (if type is 'fc'). + rnn_hidden_size: 12 # Hidden size for the RNN projector (if type is 'rnn'). + rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn'). middle_out: output_size: 8 # Size of the latent representation after message passing. - num_peers: 8 # Number of most correlated peers to consider. + num_peers: 3 # Number of most correlated peers to consider. predictor: layer_shapes: [8, 4] # List of layer sizes for the predictor. - activations: ['relu', 'none'] # Activation functions for the predictor layers. + activations: ['ReLU', 'None'] # Activation functions for the predictor layers. training: epochs: 128 # Number of training epochs. - batch_size: 8 # Batch size for training. + batch_size: 64 # Batch size for training. + num_batches: 16 # Batches per epoch learning_rate: 0.001 # Learning rate for the optimizer. eval_freq: 8 # Frequency of evaluation during training (in epochs). save_path: models # Directory to save the best model and encoder. @@ -76,7 +77,7 @@ data: url: https://content.neuralink.com/compression-challenge/data.zip # URL to download the dataset. directory: data # Directory to extract and store the dataset. split_ratio: 0.8 # Ratio to split the data into train and test sets. - cut_length: None # Optional length to cut sequences to. + cut_length: null # Optional length to cut sequences to. profiler: enable: false \ No newline at end of file diff --git a/data_processing.py b/data_processing.py index 87163e7..41102ff 100644 --- a/data_processing.py +++ b/data_processing.py @@ -21,18 +21,20 @@ def load_all_wavs(data_dir, cut_length=None): all_data = [] for file_path in wav_files: _, data = load_wav(file_path) - if cut_length: + if cut_length is not None: + print(cut_length) data = data[:cut_length] all_data.append(data) return all_data def compute_correlation_matrix(data): num_leads = len(data) - corr_matrix = np.zeros((num_leads, num_leads)) - for i in range(num_leads): - for j in range(num_leads): - if i != j: - corr_matrix[i, j] = np.corrcoef(data[i], data[j])[0, 1] + min_length = min(len(d) for d in data) + + # Trim all leads to the minimum length + trimmed_data = [d[:min_length] for d in data] + + corr_matrix = np.corrcoef(trimmed_data) return corr_matrix def split_data_by_time(data, split_ratio=0.5): diff --git a/main.py b/main.py index 21ea06f..f7ced92 100644 --- a/main.py +++ b/main.py @@ -2,21 +2,24 @@ import os import torch import torch.nn as nn import numpy as np -import random -from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution +import random, math +from utils import visualize_prediction, plot_delta_distribution +from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder import wandb -from pycallgraph import PyCallGraph -from pycallgraph.output import GraphvizOutput -import slate +from pycallgraph2 import PyCallGraph +from pycallgraph2.output import GraphvizOutput +from slate import Slate, Slate_Runner device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -class SpikeRunner: - def __init__(self, config): - self.config = config - self.name = slate.consume(config, 'name', default='Test') +class SpikeRunner(Slate_Runner): + def setup(self, name): + print("Setup SpikeRunner") + + self.name = name + slate, config = self.slate, self.config training_config = slate.consume(config, 'training', expand=True) data_config = slate.consume(config, 'data', expand=True) @@ -30,22 +33,36 @@ class SpikeRunner: self.train_data, self.test_data = split_data_by_time(all_data, split_ratio) # Compute correlation matrix + print("Computing correlation matrix") self.correlation_matrix = compute_correlation_matrix(self.train_data) + # Number of peers for message passing + self.num_peers = slate.consume(config, 'middle_out.num_peers') + + # Precompute sorted indices for the top num_peers correlated leads + print("Precomputing sorted peer indices") + self.sorted_peer_indices = np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers] + # Model setup + print("Setting up models") latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc') + latent_size = slate.consume(config, 'latent_projector.latent_size') + input_size = slate.consume(config, 'latent_projector.input_size') + output_size = slate.consume(config, 'middle_out.output_size') if latent_projector_type == 'fc': - self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device) + self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) elif latent_projector_type == 'rnn': - self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device) + self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) - self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device) - self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device) + self.middle_out = MiddleOut(latent_size=latent_size, output_size=output_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device) + self.predictor = Predictor(output_size=output_size, **slate.consume(config, 'predictor', expand=True)).to(device) # Training parameters + self.input_size = input_size self.epochs = slate.consume(training_config, 'epochs') self.batch_size = slate.consume(training_config, 'batch_size') + self.num_batches = slate.consume(training_config, 'num_batches') self.learning_rate = slate.consume(training_config, 'learning_rate') self.eval_freq = slate.consume(training_config, 'eval_freq') self.save_path = slate.consume(training_config, 'save_path') @@ -65,6 +82,7 @@ class SpikeRunner: # Optimizer self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) self.criterion = torch.nn.MSELoss() + print("SpikeRunner initialization complete") def run(self, run, forceNoProfile=False): if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile: @@ -77,58 +95,88 @@ class SpikeRunner: self.train_model() def train_model(self): - max_length = max([len(seq) for seq in self.train_data]) - print(f"Max sequence length: {max_length}") - + min_length = min([len(seq) for seq in self.train_data]) + best_test_score = float('inf') for epoch in range(self.epochs): total_loss = 0 - random.shuffle(self.train_data) - for i in range(0, len(self.train_data[0]) - self.input_size, self.input_size): - batch_data = np.array([lead[i:i+self.input_size] for lead in self.train_data]) - inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(device) - - batch_loss = 0 - for lead_idx in range(len(inputs)): - lead_data = inputs[lead_idx] - latents = self.projector(lead_data) - - for t in range(latents.shape[0]): - my_latent = latents[t] - - peer_latents = [] - peer_correlations = [] - for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]: - peer_latent = latents[t] - peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device) - peer_latents.append(peer_latent) - peer_correlations.append(peer_correlation) - - peer_latents = torch.stack(peer_latents).to(device) - peer_correlations = torch.stack(peer_correlations).to(device) - new_latent = self.middle_out(my_latent, peer_latents, peer_correlations) - prediction = self.predictor(new_latent) - target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t] - - loss = self.criterion(prediction, target) - batch_loss += loss.item() - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - total_loss += batch_loss - + for batch_num in range(self.num_batches): + + # Create indices for training data and shuffle them + indices = list(range(len(self.train_data))) + random.shuffle(indices) + + stacked_segments = [] + peer_correlations = [] + targets = [] + + for idx in indices[:self.batch_size]: + lead_data = self.train_data[idx][:min_length] + + # Slide a window over the data with overlap + stride = max(1, self.input_size // 8) # Ensuring stride is at least 1 + for i in range(0, len(lead_data) - self.input_size-1, stride): + lead_segment = lead_data[i:i + self.input_size] + inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device) + + # Collect the segments for the current lead and its peers + peer_segments = [] + for peer_idx in self.sorted_peer_indices[idx]: + peer_segment = self.train_data[peer_idx][i:i + self.input_size][:min_length] + peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device)) + peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device) # Shape: (num_peers) + peer_correlations.append(peer_correlation) + + # Stack the segments to form the batch + stacked_segment = torch.stack([inputs] + peer_segments).to(device) + stacked_segments.append(stacked_segment) + target = lead_data[i + self.input_size + 1] + targets.append(target) + + # Pass the batch through the projector + latents = self.projector(torch.stack(stacked_segments)) + + my_latent = latents[:, 0, :] + peer_latents = latents[:, 1:, :] + + # Pass through MiddleOut + new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_correlations)) + prediction = self.predictor(new_latent) + + # Calculate loss and backpropagate + loss = self.criterion(prediction, torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)) + total_loss += loss.item() + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch) - print(f'Epoch {epoch+1}/{self.epochs}, Loss: {total_loss}') - - if (epoch + 1) % self.eval_freq == 0: + print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') + + if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: + print(f'Starting evaluation for epoch {epoch + 1}') test_loss = self.evaluate_model(epoch) if test_loss < best_test_score: best_test_score = test_loss self.save_models(epoch) + print(f'Evaluation complete for epoch {epoch + 1}') + + + wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch) + print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') + + if (epoch + 1) % self.eval_freq == 0: + print(f'Starting evaluation for epoch {epoch + 1}') + test_loss = self.evaluate_model(epoch) + if test_loss < best_test_score: + best_test_score = test_loss + self.save_models(epoch) + print(f'Evaluation complete for epoch {epoch + 1}') + def evaluate_model(self, epoch): + print('Evaluating model...') self.projector.eval() self.middle_out.eval() self.predictor.eval() @@ -143,59 +191,82 @@ class SpikeRunner: with torch.no_grad(): for lead_idx in range(len(self.test_data)): - lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device) - latents = self.projector(lead_data) - + lead_data = self.test_data[lead_idx] true_data = [] predicted_data = [] delta_data = [] + targets = [] - for t in range(latents.shape[0]): - my_latent = latents[t] + min_length = min([len(seq) for seq in self.test_data]) - peer_latents = [] - peer_correlations = [] - for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]: - peer_latent = latents[t] - peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device) - peer_latents.append(peer_latent) - peer_correlations.append(peer_correlation) + # Initialize lists to store segments and peer correlations + stacked_segments = [] + peer_correlations = [] - peer_latents = torch.stack(peer_latents).to(device) - peer_correlations = torch.stack(peer_correlations).to(device) - new_latent = self.middle_out(my_latent, peer_latents, peer_correlations) - prediction = self.predictor(new_latent) - target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t] + for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8): + lead_segment = lead_data[i:i + self.input_size] + inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device) - loss = self.criterion(prediction, target) - total_loss += loss.item() + # Collect peer segments and correlations + peer_segments = [] + for peer_idx in self.sorted_peer_indices[lead_idx]: + peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length] + peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device)) + peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device) + peer_correlations.append(peer_correlation) - true_data.append(target.cpu().numpy()) - predicted_data.append(prediction.cpu().numpy()) - delta_data.append((target - prediction).cpu().numpy()) + # Stack segments to form the batch + stacked_segment = torch.stack([inputs] + peer_segments).to(device) + stacked_segments.append(stacked_segment) + target = lead_data[i + self.input_size + 1] + targets.append(target) + # Pass the batch through the projector + latents = self.projector(torch.stack(stacked_segments)) + + my_latents = latents[:, 0, :] + peer_latents = latents[:, 1:, :] + + # Pass through MiddleOut + new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_correlations)) + + # Predict using the predictor + predictions = self.predictor(new_latents) + + # Compute loss and store true and predicted data + for i, segment in enumerate(stacked_segments): + for t in range(self.input_size): + target = torch.tensor(targets[i]) + true_data.append(target.cpu().numpy()) + predicted_data.append(predictions[i, t, :].cpu().numpy()) + delta_data.append((target - predictions[i, t, :]).cpu().numpy()) + + loss = self.criterion(predictions[i, t, :], target) + total_loss += loss.item() + + # Append true and predicted data for this lead sequence all_true.append(true_data) all_predicted.append(predicted_data) all_deltas.append(delta_data) if self.full_compression: - self.encoder.build_model(latents.cpu().numpy()) - compressed_data = self.encoder.encode(latents.cpu().numpy()) - decompressed_data = self.encoder.decode(compressed_data, len(latents)) - compression_ratio = len(latents) / len(compressed_data) + # Bitstream encoding + self.encoder.build_model(my_latents.cpu().numpy()) + compressed_data = self.encoder.encode(my_latents.cpu().numpy()) + decompressed_data = self.encoder.decode(compressed_data, len(my_latents)) + compression_ratio = len(my_latents) / len(compressed_data) compression_ratios.append(compression_ratio) # Check if decompressed data matches the original data - if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5): + if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5): exact_matches += 1 total_sequences += 1 - visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch) - avg_loss = total_loss / len(self.test_data) print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') wandb.log({"evaluation_loss": avg_loss}, step=epoch) + # Visualize delta distribution delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch) wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) @@ -207,14 +278,19 @@ class SpikeRunner: wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) + print('Evaluation done for this epoch.') return avg_loss + def save_models(self, epoch): + print('Saving models...') torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt")) torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt")) torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt")) print(f"New high score! Models saved at epoch {epoch+1}.") if __name__ == '__main__': + print('Initializing...') slate = Slate({'spikey': SpikeRunner}) slate.from_args() + print('Done.') \ No newline at end of file diff --git a/models.py b/models.py index 90221c4..cf9971d 100644 --- a/models.py +++ b/models.py @@ -49,8 +49,9 @@ class MiddleOut(nn.Module): def forward(self, my_latent, peer_latents, peer_correlations): new_latents = [] - for peer_latent, correlation in zip(peer_latents, peer_correlations): - combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1) + for p in range(peer_latents.shape[-2]): + peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p] + combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1) new_latent = self.fc(combined_input) new_latents.append(new_latent)