Fixed bugs

parent 97de63e946
commit bc783b9888

17  config.yaml

@@ -45,22 +45,23 @@ latent_projector:
   type: rnn  # Options: 'fc', 'rnn'
   input_size: 19531  # =1s Input size for the Latent Projector (length of snippets).
   latent_size: 8  # Size of the latent representation before message passing.
-  layer_shapes: [256, 32]  # List of layer sizes for the latent projector (if type is 'fc').
-  activations: ['relu', 'relu']  # Activation functions for the latent projector layers (if type is 'fc').
-  rnn_hidden_size: 16  # Hidden size for the RNN projector (if type is 'rnn').
-  rnn_num_layers: 2  # Number of layers for the RNN projector (if type is 'rnn').
+  #layer_shapes: [256, 32]  # List of layer sizes for the latent projector (if type is 'fc').
+  #activations: ['ReLU', 'ReLU']  # Activation functions for the latent projector layers (if type is 'fc').
+  rnn_hidden_size: 12  # Hidden size for the RNN projector (if type is 'rnn').
+  rnn_num_layers: 1  # Number of layers for the RNN projector (if type is 'rnn').
 
 middle_out:
   output_size: 8  # Size of the latent representation after message passing.
-  num_peers: 8  # Number of most correlated peers to consider.
+  num_peers: 3  # Number of most correlated peers to consider.
 
 predictor:
   layer_shapes: [8, 4]  # List of layer sizes for the predictor.
-  activations: ['relu', 'none']  # Activation functions for the predictor layers.
+  activations: ['ReLU', 'None']  # Activation functions for the predictor layers.
 
 training:
   epochs: 128  # Number of training epochs.
-  batch_size: 8  # Batch size for training.
+  batch_size: 64  # Batch size for training.
+  num_batches: 16 # Batches per epoch
   learning_rate: 0.001  # Learning rate for the optimizer.
   eval_freq: 8  # Frequency of evaluation during training (in epochs).
   save_path: models  # Directory to save the best model and encoder.
@@ -76,7 +77,7 @@ data:
   url: https://content.neuralink.com/compression-challenge/data.zip  # URL to download the dataset.
   directory: data  # Directory to extract and store the dataset.
   split_ratio: 0.8  # Ratio to split the data into train and test sets.
-  cut_length: None  # Optional length to cut sequences to.
+  cut_length: null  # Optional length to cut sequences to.
 
 profiler:
   enable: false
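Note on the `cut_length` change above: YAML has no `None` keyword, so the old value loaded as the string 'None' rather than as a missing value, which is why the data loader below also switches to an explicit `is not None` check. A minimal sketch of the difference, assuming the config is read with a standard YAML loader such as PyYAML:

# Sketch only: why `cut_length: None` had to become `cut_length: null`.
import yaml

old = yaml.safe_load("cut_length: None")   # {'cut_length': 'None'}  -> a truthy string
new = yaml.safe_load("cut_length: null")   # {'cut_length': None}    -> a real Python None
print(old, new)
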
@@ -21,18 +21,20 @@ def load_all_wavs(data_dir, cut_length=None):
     all_data = []
     for file_path in wav_files:
         _, data = load_wav(file_path)
-        if cut_length:
+        if cut_length is not None:
+            print(cut_length)
             data = data[:cut_length]
         all_data.append(data)
     return all_data
 
 def compute_correlation_matrix(data):
     num_leads = len(data)
-    corr_matrix = np.zeros((num_leads, num_leads))
-    for i in range(num_leads):
-        for j in range(num_leads):
-            if i != j:
-                corr_matrix[i, j] = np.corrcoef(data[i], data[j])[0, 1]
+    min_length = min(len(d) for d in data)
+
+    # Trim all leads to the minimum length
+    trimmed_data = [d[:min_length] for d in data]
+
+    corr_matrix = np.corrcoef(trimmed_data)
     return corr_matrix
 
 def split_data_by_time(data, split_ratio=0.5):
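For reference, the rewritten `compute_correlation_matrix` relies on `np.corrcoef` treating each row of a 2-D input as one variable, so a single call over the trimmed leads reproduces the removed pairwise loop. A small self-contained check with toy data (not from the repository):

# Sketch: one np.corrcoef call on the trimmed leads matches the old pairwise version,
# except that the diagonal is now 1.0 (self-correlation) instead of the old 0.0.
import numpy as np

leads = [np.random.randn(1000), np.random.randn(1200), np.random.randn(900)]
min_length = min(len(d) for d in leads)
trimmed = [d[:min_length] for d in leads]

corr = np.corrcoef(trimmed)                           # shape: (num_leads, num_leads)
pairwise = np.corrcoef(trimmed[0], trimmed[1])[0, 1]  # single pair, old style
assert np.isclose(corr[0, 1], pairwise)
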

208  main.py

@@ -2,21 +2,24 @@ import os
 import torch
 import torch.nn as nn
 import numpy as np
-import random
-from utils import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix, visualize_prediction, plot_delta_distribution
+import random, math
+from utils import visualize_prediction, plot_delta_distribution
+from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix
 from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
 from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
 import wandb
-from pycallgraph import PyCallGraph
-from pycallgraph.output import GraphvizOutput
-import slate
+from pycallgraph2 import PyCallGraph
+from pycallgraph2.output import GraphvizOutput
+from slate import Slate, Slate_Runner
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-class SpikeRunner:
-    def __init__(self, config):
-        self.config = config
-        self.name = slate.consume(config, 'name', default='Test')
+class SpikeRunner(Slate_Runner):
+    def setup(self, name):
+        print("Setup SpikeRunner")
+
+        self.name = name
+        slate, config = self.slate, self.config
+
         training_config = slate.consume(config, 'training', expand=True)
         data_config = slate.consume(config, 'data', expand=True)
@@ -30,22 +33,36 @@ class SpikeRunner:
         self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
 
         # Compute correlation matrix
+        print("Computing correlation matrix")
         self.correlation_matrix = compute_correlation_matrix(self.train_data)
 
+        # Number of peers for message passing
+        self.num_peers = slate.consume(config, 'middle_out.num_peers')
+
+        # Precompute sorted indices for the top num_peers correlated leads
+        print("Precomputing sorted peer indices")
+        self.sorted_peer_indices = np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers]
+
         # Model setup
+        print("Setting up models")
         latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
+        latent_size = slate.consume(config, 'latent_projector.latent_size')
+        input_size = slate.consume(config, 'latent_projector.input_size')
+        output_size = slate.consume(config, 'middle_out.output_size')
+
         if latent_projector_type == 'fc':
-            self.projector = LatentProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
         elif latent_projector_type == 'rnn':
-            self.projector = LatentRNNProjector(**slate.consume(config, 'latent_projector', expand=True)).to(device)
+            self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
 
-        self.middle_out = MiddleOut(**slate.consume(config, 'middle_out', expand=True)).to(device)
-        self.predictor = Predictor(**slate.consume(config, 'predictor', expand=True)).to(device)
+        self.middle_out = MiddleOut(latent_size=latent_size, output_size=output_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
+        self.predictor = Predictor(output_size=output_size, **slate.consume(config, 'predictor', expand=True)).to(device)
 
         # Training parameters
+        self.input_size = input_size
         self.epochs = slate.consume(training_config, 'epochs')
         self.batch_size = slate.consume(training_config, 'batch_size')
+        self.num_batches = slate.consume(training_config, 'num_batches')
         self.learning_rate = slate.consume(training_config, 'learning_rate')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
         self.save_path = slate.consume(training_config, 'save_path')
@@ -65,6 +82,7 @@ class SpikeRunner:
         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
         self.criterion = torch.nn.MSELoss()
+        print("SpikeRunner initialization complete")
 
     def run(self, run, forceNoProfile=False):
         if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
@@ -77,58 +95,88 @@ class SpikeRunner:
         self.train_model()
 
     def train_model(self):
-        max_length = max([len(seq) for seq in self.train_data])
-        print(f"Max sequence length: {max_length}")
+        min_length = min([len(seq) for seq in self.train_data])
 
         best_test_score = float('inf')
 
         for epoch in range(self.epochs):
             total_loss = 0
-            random.shuffle(self.train_data)
-            for i in range(0, len(self.train_data[0]) - self.input_size, self.input_size):
-                batch_data = np.array([lead[i:i+self.input_size] for lead in self.train_data])
-                inputs = torch.tensor(batch_data, dtype=torch.float32).unsqueeze(2).to(device)
+            for batch_num in range(self.num_batches):
 
-                batch_loss = 0
-                for lead_idx in range(len(inputs)):
-                    lead_data = inputs[lead_idx]
-                    latents = self.projector(lead_data)
+                # Create indices for training data and shuffle them
+                indices = list(range(len(self.train_data)))
+                random.shuffle(indices)
 
-                    for t in range(latents.shape[0]):
-                        my_latent = latents[t]
-
-                        peer_latents = []
+                stacked_segments = []
                 peer_correlations = []
-                        for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
-                            peer_latent = latents[t]
-                            peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
-                            peer_latents.append(peer_latent)
+                targets = []
+
+                for idx in indices[:self.batch_size]:
+                    lead_data = self.train_data[idx][:min_length]
+
+                    # Slide a window over the data with overlap
+                    stride = max(1, self.input_size // 8)  # Ensuring stride is at least 1
+                    for i in range(0, len(lead_data) - self.input_size-1, stride):
+                        lead_segment = lead_data[i:i + self.input_size]
+                        inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
+
+                        # Collect the segments for the current lead and its peers
+                        peer_segments = []
+                        for peer_idx in self.sorted_peer_indices[idx]:
+                            peer_segment = self.train_data[peer_idx][i:i + self.input_size][:min_length]
+                            peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
+                        peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)  # Shape: (num_peers)
                         peer_correlations.append(peer_correlation)
 
-                        peer_latents = torch.stack(peer_latents).to(device)
-                        peer_correlations = torch.stack(peer_correlations).to(device)
-                        new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
-                        prediction = self.predictor(new_latent)
-                        target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
+                        # Stack the segments to form the batch
+                        stacked_segment = torch.stack([inputs] + peer_segments).to(device)
+                        stacked_segments.append(stacked_segment)
+                        target = lead_data[i + self.input_size + 1]
+                        targets.append(target)
 
-                        loss = self.criterion(prediction, target)
-                        batch_loss += loss.item()
+                # Pass the batch through the projector
+                latents = self.projector(torch.stack(stacked_segments))
 
+                my_latent = latents[:, 0, :]
+                peer_latents = latents[:, 1:, :]
+
+                # Pass through MiddleOut
+                new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_correlations))
+                prediction = self.predictor(new_latent)
+
+                # Calculate loss and backpropagate
+                loss = self.criterion(prediction, torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device))
+                total_loss += loss.item()
                 self.optimizer.zero_grad()
                 loss.backward()
                 self.optimizer.step()
 
-                total_loss += batch_loss
+            wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
+            print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
+
+            if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
+                print(f'Starting evaluation for epoch {epoch + 1}')
+                test_loss = self.evaluate_model(epoch)
+                if test_loss < best_test_score:
+                    best_test_score = test_loss
+                    self.save_models(epoch)
+                print(f'Evaluation complete for epoch {epoch + 1}')
+
+
                 wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
                 print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
 
                 if (epoch + 1) % self.eval_freq == 0:
+                    print(f'Starting evaluation for epoch {epoch + 1}')
                     test_loss = self.evaluate_model(epoch)
                     if test_loss < best_test_score:
                         best_test_score = test_loss
                         self.save_models(epoch)
+                    print(f'Evaluation complete for epoch {epoch + 1}')
 
 
     def evaluate_model(self, epoch):
+        print('Evaluating model...')
         self.projector.eval()
         self.middle_out.eval()
         self.predictor.eval()
@@ -143,59 +191,82 @@ class SpikeRunner:
 
         with torch.no_grad():
             for lead_idx in range(len(self.test_data)):
-                lead_data = torch.tensor(self.test_data[lead_idx], dtype=torch.float32).unsqueeze(1).to(device)
-                latents = self.projector(lead_data)
-
+                lead_data = self.test_data[lead_idx]
                 true_data = []
                 predicted_data = []
                 delta_data = []
+                targets = []
+
+                min_length = min([len(seq) for seq in self.test_data])
 
-                for t in range(latents.shape[0]):
-                    my_latent = latents[t]
-
-                    peer_latents = []
+                # Initialize lists to store segments and peer correlations
+                stacked_segments = []
                 peer_correlations = []
-                    for peer_idx in np.argsort(self.correlation_matrix[lead_idx])[-self.num_peers:]:
-                        peer_latent = latents[t]
-                        peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx]], dtype=torch.float32).to(device)
-                        peer_latents.append(peer_latent)
+
+                for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
+                    lead_segment = lead_data[i:i + self.input_size]
+                    inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
+
+                    # Collect peer segments and correlations
+                    peer_segments = []
+                    for peer_idx in self.sorted_peer_indices[lead_idx]:
+                        peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
+                        peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
+                    peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
                     peer_correlations.append(peer_correlation)
 
-                    peer_latents = torch.stack(peer_latents).to(device)
-                    peer_correlations = torch.stack(peer_correlations).to(device)
-                    new_latent = self.middle_out(my_latent, peer_latents, peer_correlations)
-                    prediction = self.predictor(new_latent)
-                    target = lead_data[t+1] if t < latents.shape[0] - 1 else lead_data[t]
+                    # Stack segments to form the batch
+                    stacked_segment = torch.stack([inputs] + peer_segments).to(device)
+                    stacked_segments.append(stacked_segment)
+                    target = lead_data[i + self.input_size + 1]
+                    targets.append(target)
 
-                    loss = self.criterion(prediction, target)
+                # Pass the batch through the projector
+                latents = self.projector(torch.stack(stacked_segments))
+
+                my_latents = latents[:, 0, :]
+                peer_latents = latents[:, 1:, :]
+
+                # Pass through MiddleOut
+                new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_correlations))
+
+                # Predict using the predictor
+                predictions = self.predictor(new_latents)
+
+                # Compute loss and store true and predicted data
+                for i, segment in enumerate(stacked_segments):
+                    for t in range(self.input_size):
+                        target = torch.tensor(targets[i])
+                        true_data.append(target.cpu().numpy())
+                        predicted_data.append(predictions[i, t, :].cpu().numpy())
+                        delta_data.append((target - predictions[i, t, :]).cpu().numpy())
+
+                        loss = self.criterion(predictions[i, t, :], target)
                         total_loss += loss.item()
 
-                    true_data.append(target.cpu().numpy())
-                    predicted_data.append(prediction.cpu().numpy())
-                    delta_data.append((target - prediction).cpu().numpy())
-
+                # Append true and predicted data for this lead sequence
                 all_true.append(true_data)
                 all_predicted.append(predicted_data)
                 all_deltas.append(delta_data)
 
                 if self.full_compression:
-                    self.encoder.build_model(latents.cpu().numpy())
-                    compressed_data = self.encoder.encode(latents.cpu().numpy())
-                    decompressed_data = self.encoder.decode(compressed_data, len(latents))
-                    compression_ratio = len(latents) / len(compressed_data)
+                    # Bitstream encoding
+                    self.encoder.build_model(my_latents.cpu().numpy())
+                    compressed_data = self.encoder.encode(my_latents.cpu().numpy())
+                    decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
+                    compression_ratio = len(my_latents) / len(compressed_data)
                     compression_ratios.append(compression_ratio)
 
                     # Check if decompressed data matches the original data
-                    if np.allclose(latents.cpu().numpy(), decompressed_data, atol=1e-5):
+                    if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
                         exact_matches += 1
                     total_sequences += 1
 
-                visualize_prediction(np.array(true_data), np.array(predicted_data), np.array(delta_data), sample_rate=1, epoch=epoch)
-
         avg_loss = total_loss / len(self.test_data)
         print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
         wandb.log({"evaluation_loss": avg_loss}, step=epoch)
 
+        # Visualize delta distribution
         delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
         wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
 
@@ -207,14 +278,19 @@ class SpikeRunner:
             wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
             wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
 
+        print('Evaluation done for this epoch.')
         return avg_loss
 
+
     def save_models(self, epoch):
+        print('Saving models...')
         torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
         torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
         torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
         print(f"New high score! Models saved at epoch {epoch+1}.")
 
 if __name__ == '__main__':
+    print('Initializing...')
     slate = Slate({'spikey': SpikeRunner})
     slate.from_args()
+    print('Done.')
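The `sorted_peer_indices` lookup added in `setup` above ranks, per lead, the columns of the correlation matrix from most to least correlated and keeps the first `num_peers`. A toy illustration with invented values (not data from the repository):

# Sketch: argsort of the negated correlation matrix gives descending order per row.
# With a diagonal of 1.0, each lead's own index comes first in its row.
import numpy as np

corr = np.array([[1.0, 0.8, 0.1],
                 [0.8, 1.0, 0.3],
                 [0.1, 0.3, 1.0]])
num_peers = 2
sorted_peer_indices = np.argsort(-corr, axis=1)[:, :num_peers]
print(sorted_peer_indices)  # [[0 1]
                            #  [1 0]
                            #  [2 1]]
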
@@ -49,8 +49,9 @@ class MiddleOut(nn.Module):
 
     def forward(self, my_latent, peer_latents, peer_correlations):
         new_latents = []
-        for peer_latent, correlation in zip(peer_latents, peer_correlations):
-            combined_input = torch.cat((my_latent, peer_latent, correlation), dim=-1)
+        for p in range(peer_latents.shape[-2]):
+            peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
+            combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
             new_latent = self.fc(combined_input)
             new_latents.append(new_latent)
 
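The reworked `MiddleOut.forward` above iterates over the peer axis of batch-first tensors. A standalone shape sketch with toy dimensions (sizes assumed for illustration, not the repository's defaults):

# Sketch of the expected shapes:
#   my_latent:         (batch, latent_size)
#   peer_latents:      (batch, num_peers, latent_size)
#   peer_correlations: (batch, num_peers)
import torch

batch, num_peers, latent_size = 4, 3, 8
my_latent = torch.randn(batch, latent_size)
peer_latents = torch.randn(batch, num_peers, latent_size)
peer_correlations = torch.randn(batch, num_peers)

for p in range(peer_latents.shape[-2]):
    peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
    combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
    print(combined_input.shape)  # torch.Size([4, 17]) -> 2 * latent_size + 1 per peer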