Updated README and var names
This commit is contained in:
parent
fb51634417
commit
fb8dbd30ee
25
README.md
25
README.md
@ -10,25 +10,35 @@ The Neuralink N1 implant generates approximately 200Mbps of electrode data (1024
|
||||
|
||||
The `analysis.ipynb` notebook contains a detailed analysis of the data. We found that there is sometimes significant cross-correlation between the different leads, so we find it vital to use this information for better compression. This cross-correlation allows us to improve the accuracy of our predictions and reduce the overall amount of data that needs to be transmitted. As part of the analysis, we also note that achieving a 200x compression ratio is highly unlikely to be possible and is also nonsensical, a very close reproduction is sufficient.
|
||||
|
||||
## Compression Details
|
||||
## Algorithm Overview
|
||||
|
||||
### 1 - Thread Topology Reconstruction
|
||||
|
||||
As the first step we analyse reading from the leads to construct an approximative topology of the threads in the brain. The distance metric we generate only approximately represents true euclidean distances, but rather the 'distance' in common activity. This topology must only be computed once for a given implant and maybe updated fro thread movements, but is not part of the regular compression/decompression process.
|
||||
|
||||
### 2 - Predictive Architecture
|
||||
|
||||
The main workhorse of our compression approach is a predictive model running both in the compressor and decompressor. With good predictions of the data, only the error between the prediction and actual data nmust be transmitted. We make use of the previously constructed topology to allow the predictive model's latent to represent activity of brain regions based on the reading of the threads instead of just for threads themselves.
|
||||
|
||||
The solution leverages three neural network models to achieve effective compression:
|
||||
|
||||
1. **Latent Projector**: This module takes in a segment of a lead and projects it into a latent space. The latent projector can be configured as a fully connected network or an RNN (LSTM) based on the configuration.
|
||||
1. **Latent Projector**: This module takes in a segment of a lead and projects it into a latent space. The latent projector can be configured as a fully connected network or an RNN (LSTM) with arbitrary shape.
|
||||
|
||||
2. **MiddleOut (Message Passer)**: For each lead, this module looks up the `n` most correlated leads and uses their latent representations along with their correlation values to generate a new latent representation. This is done by training a fully connected layer to map from (our_latent, their_latent, correlation) -> new_latent and then averaging over all new_latent values to get the final representation.
|
||||
2. **MiddleOut (Message Passer)**: For each lead, this module perfroms message passing according to the thread topology. Their latent representations along with their distance metrics are used to generate joint latent representation. This is done by training a fully connected layer to map from (our_latent, their_latent, metcric) -> joint_latent and then averaging over all joint_latent values to get the final representation.
|
||||
|
||||
3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training.
|
||||
3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training. Can be configured to be an FCNN of arbitrary shape.
|
||||
|
||||
By accurately predicting the next timestep, the delta (difference) between the actual value and the predicted value is minimized. Small deltas mean that fewer bits are needed to store these values, which are then efficiently encoded using the bitstream encoder.
|
||||
The neural networks used in this solution are rather small, making it possible to meet the latency and power requirements if implemented more efficiently.
|
||||
|
||||
The neural networks used in this solution are tiny, making it possible to meet the latency and power requirements if implemented more efficiently.
|
||||
### 3 - Efficient Bitstream Encoding
|
||||
|
||||
Based on an expected distribution of deltas, that have to be transmitted an efficient huffman-like binary format is used for encoding of the data.
|
||||
|
||||
## TODO
|
||||
|
||||
- All currently implemented bitstream encoders are rather naive. We know, that lead values from the N1 only have 10 bit precision, but wav file provides yus with 32bit floats. All my bitstream encoders are also based on 32bit floats, discretizing back into the 10 bit space would be a low hanging fruit for ~3.2x compression.
|
||||
- Since we merely encode the remaining delta, we can go even more efficient by constructing something along the lines of a huffman tree.
|
||||
- Loss is not coming down during training...
|
||||
- Loss is not coming down during training... So basicaly nothing works right now. But the text I wrote is cool, right?
|
||||
- Make a logo
|
||||
|
||||
## Installation
|
||||
@ -56,4 +66,3 @@ To train the model, run:
|
||||
```bash
|
||||
python main.py <config_file.yaml> <exp_name>
|
||||
```
|
||||
|
||||
|
@ -27,16 +27,16 @@ def load_all_wavs(data_dir, cut_length=None):
|
||||
all_data.append(data)
|
||||
return all_data
|
||||
|
||||
def compute_correlation_matrix(data):
|
||||
def compute_topology_metrics(data):
|
||||
num_leads = len(data)
|
||||
min_length = min(len(d) for d in data)
|
||||
|
||||
# Trim all leads to the minimum length
|
||||
trimmed_data = [d[:min_length] for d in data]
|
||||
|
||||
corr_matrix = np.corrcoef(trimmed_data)
|
||||
np.fill_diagonal(corr_matrix, 0)
|
||||
return corr_matrix
|
||||
metric_matrix = np.corrcoef(trimmed_data)
|
||||
np.fill_diagonal(metric_matrix, 0)
|
||||
return metric_matrix
|
||||
|
||||
def split_data_by_time(data, split_ratio=0.5):
|
||||
train_data = []
|
||||
|
29
main.py
29
main.py
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
import numpy as np
|
||||
import random, math
|
||||
from utils import visualize_prediction, plot_delta_distribution
|
||||
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_correlation_matrix
|
||||
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
|
||||
from models import LatentProjector, LatentRNNProjector, MiddleOut, Predictor
|
||||
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
|
||||
import wandb
|
||||
@ -33,16 +33,15 @@ class SpikeRunner(Slate_Runner):
|
||||
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
|
||||
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
|
||||
|
||||
# Compute correlation matrix
|
||||
print("Computing correlation matrix")
|
||||
self.correlation_matrix = compute_correlation_matrix(self.train_data)
|
||||
print("Reconstructing thread topology")
|
||||
self.topology_matrix = compute_topology_metrics(self.train_data)
|
||||
|
||||
# Number of peers for message passing
|
||||
self.num_peers = slate.consume(config, 'middle_out.num_peers')
|
||||
|
||||
# Precompute sorted indices for the top num_peers correlated leads
|
||||
print("Precomputing sorted peer indices")
|
||||
self.sorted_peer_indices = np.argsort(-self.correlation_matrix, axis=1)[:, :self.num_peers]
|
||||
self.sorted_peer_indices = np.argsort(-self.topology_matrix, axis=1)[:, :self.num_peers]
|
||||
|
||||
# Model setup
|
||||
print("Setting up models")
|
||||
@ -111,7 +110,7 @@ class SpikeRunner(Slate_Runner):
|
||||
random.shuffle(indices)
|
||||
|
||||
stacked_segments = []
|
||||
peer_correlations = []
|
||||
peer_metrics = []
|
||||
targets = []
|
||||
|
||||
for idx in indices[:self.batch_size]:
|
||||
@ -128,8 +127,8 @@ class SpikeRunner(Slate_Runner):
|
||||
for peer_idx in self.sorted_peer_indices[idx]:
|
||||
peer_segment = self.train_data[peer_idx][i:i + self.input_size]
|
||||
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
|
||||
peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
|
||||
peer_correlations.append(peer_correlation)
|
||||
peer_metric = torch.tensor([self.topology_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
|
||||
peer_metrics.append(peer_metric)
|
||||
|
||||
# Stack the segments to form the batch
|
||||
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
|
||||
@ -144,7 +143,7 @@ class SpikeRunner(Slate_Runner):
|
||||
peer_latents = latents[:, 1:, :]
|
||||
|
||||
# Pass through MiddleOut
|
||||
new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_correlations))
|
||||
new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
|
||||
prediction = self.predictor(new_latent)
|
||||
|
||||
# Calculate loss and backpropagate
|
||||
@ -209,21 +208,21 @@ class SpikeRunner(Slate_Runner):
|
||||
|
||||
min_length = min([len(seq) for seq in self.test_data])
|
||||
|
||||
# Initialize lists to store segments and peer correlations
|
||||
# Initialize lists to store segments and peer metrics
|
||||
stacked_segments = []
|
||||
peer_correlations = []
|
||||
peer_metrics = []
|
||||
|
||||
for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
|
||||
lead_segment = lead_data[i:i + self.input_size]
|
||||
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
|
||||
|
||||
# Collect peer segments and correlations
|
||||
# Collect peer segments and metrics
|
||||
peer_segments = []
|
||||
for peer_idx in self.sorted_peer_indices[lead_idx]:
|
||||
peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
|
||||
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
|
||||
peer_correlation = torch.tensor([self.correlation_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
|
||||
peer_correlations.append(peer_correlation)
|
||||
peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
|
||||
peer_metrics.append(peer_metric)
|
||||
|
||||
# Stack segments to form the batch
|
||||
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
|
||||
@ -238,7 +237,7 @@ class SpikeRunner(Slate_Runner):
|
||||
peer_latents = latents[:, 1:, :]
|
||||
|
||||
# Pass through MiddleOut
|
||||
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_correlations))
|
||||
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics))
|
||||
|
||||
# Predict using the predictor
|
||||
predictions = self.predictor(new_latents)
|
||||
|
@ -48,13 +48,13 @@ class MiddleOut(nn.Module):
|
||||
self.num_peers = num_peers
|
||||
self.fc = nn.Linear(latent_size * 2 + 1, output_size)
|
||||
|
||||
def forward(self, my_latent, peer_latents, peer_correlations):
|
||||
def forward(self, my_latent, peer_latents, peer_metrics):
|
||||
new_latents = []
|
||||
for p in range(peer_latents.shape[-2]):
|
||||
peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
|
||||
combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
|
||||
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
|
||||
combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1)
|
||||
new_latent = self.fc(combined_input)
|
||||
new_latents.append(new_latent * correlation.unsqueeze(1))
|
||||
new_latents.append(new_latent * metric.unsqueeze(1))
|
||||
|
||||
new_latents = torch.stack(new_latents)
|
||||
averaged_latent = torch.mean(new_latents, dim=0)
|
||||
|
Loading…
Reference in New Issue
Block a user