README typos and smol refactor

This commit is contained in:
Dominik Moritz Roth 2024-05-26 13:56:59 +02:00
parent fb8dbd30ee
commit ba1caf7d80
4 changed files with 23 additions and 21 deletions

View File

@ -4,41 +4,43 @@ This repository contains a solution for the [Neuralink Compression Challenge](ht
## Challenge Overview ## Challenge Overview
The Neuralink N1 implant generates approximately 200Mbps of electrode data (1024 electrodes @ 20kHz, 10-bit resolution) and can transmit data wirelessly at about 1Mbps. This means a compression ratio of over 200x is required. The compression must run in real-time (< 1ms) and consume low power (< 10mW, including radio). The Neuralink N1 implant generates approximately 200 Mbps of electrode data (1024 electrodes @ 20 kHz, 10-bit resolution) and can transmit data wirelessly at about 1 Mbps. This means a compression ratio of over 200x is required. The compression must run in real-time (< 1 ms) and consume low power (< 10 mW, including radio).
## Data Analysis ## Data Analysis
The `analysis.ipynb` notebook contains a detailed analysis of the data. We found that there is sometimes significant cross-correlation between the different leads, so we find it vital to use this information for better compression. This cross-correlation allows us to improve the accuracy of our predictions and reduce the overall amount of data that needs to be transmitted. As part of the analysis, we also note that achieving a 200x compression ratio is highly unlikely to be possible and is also nonsensical, a very close reproduction is sufficient. The `analysis.ipynb` notebook contains a detailed analysis of the data. We found that there is sometimes significant cross-correlation between the different leads, so we find it vital to use this information for better compression. This cross-correlation allows us to improve the accuracy of our predictions and reduce the overall amount of data that needs to be transmitted. As part of the analysis, we also note that achieving a 200x compression ratio is highly unlikely to be possible and is also nonsensical; a very close reproduction is sufficient.
## Algorithm Overview ## Algorithm Overview
### 1 - Thread Topology Reconstruction ### 1 - Thread Topology Reconstruction
As the first step we analyse reading from the leads to construct an approximative topology of the threads in the brain. The distance metric we generate only approximately represents true euclidean distances, but rather the 'distance' in common activity. This topology must only be computed once for a given implant and maybe updated fro thread movements, but is not part of the regular compression/decompression process. As the first step, we analyze readings from the leads to construct an approximate topology of the threads in the brain. The distance metric we generate only approximately represents true Euclidean distances, but rather the 'distance' in common activity. This topology must only be computed once for a given implant and may be updated for thread movements but is not part of the regular compression/decompression process.
### 2 - Predictive Architecture ### 2 - Predictive Architecture
The main workhorse of our compression approach is a predictive model running both in the compressor and decompressor. With good predictions of the data, only the error between the prediction and actual data nmust be transmitted. We make use of the previously constructed topology to allow the predictive model's latent to represent activity of brain regions based on the reading of the threads instead of just for threads themselves. The main workhorse of our compression approach is a predictive model running both in the compressor and decompressor. With good predictions of the data, only the error between the prediction and actual data must be transmitted. We make use of the previously constructed topology to allow the predictive model's latent to represent the activity of brain regions based on the reading of the threads instead of just for threads themselves.
The solution leverages three neural network models to achieve effective compression: The solution leverages three neural network models to achieve effective compression:
1. **Latent Projector**: This module takes in a segment of a lead and projects it into a latent space. The latent projector can be configured as a fully connected network or an RNN (LSTM) with arbitrary shape. 1. **Latent Projector**: This module takes in a segment of a lead and projects it into a latent space. The latent projector can be configured as a fully connected network or an RNN (LSTM) with an arbitrary shape.
2. **MiddleOut (Message Passer)**: For each lead, this module perfroms message passing according to the thread topology. Their latent representations along with their distance metrics are used to generate joint latent representation. This is done by training a fully connected layer to map from (our_latent, their_latent, metcric) -> joint_latent and then averaging over all joint_latent values to get the final representation. 2. **MiddleOut (Message Passer)**: For each lead, this module performs message passing according to the thread topology. Their latent representations along with their distance metrics are used to generate region latent representation. This is done by training a fully connected layer to map from (our_latent, their_latent, metric) -> region_latent and then averaging over all region_latent values to get the final representation.
3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training. Can be configured to be an FCNN of arbitrary shape. 3. **Predictor**: This module takes the new latent representation from the MiddleOut module and predicts the next timestep. The goal is to minimize the prediction error during training. It can be configured to be an FCNN of arbitrary shape.
The neural networks used in this solution are rather small, making it possible to meet the latency and power requirements if implemented more efficiently. The neural networks used in this solution are rather small, making it possible to meet the latency and power requirements if implemented more efficiently.
If we were to give up on lossless compression, one could expand MiddleOut to form a joint latent over all threads and transmit that.
### 3 - Efficient Bitstream Encoding ### 3 - Efficient Bitstream Encoding
Based on an expected distribution of deltas, that have to be transmitted an efficient huffman-like binary format is used for encoding of the data. Based on an expected distribution of deltas that have to be transmitted, an efficient Huffman-like binary format is used for encoding the data.
## TODO ## TODO
- All currently implemented bitstream encoders are rather naive. We know, that lead values from the N1 only have 10 bit precision, but wav file provides yus with 32bit floats. All my bitstream encoders are also based on 32bit floats, discretizing back into the 10 bit space would be a low hanging fruit for ~3.2x compression. - All currently implemented bitstream encoders are rather naive. We know that lead values from the N1 only have 10-bit precision, but the WAV file provides us with 32-bit floats. All my bitstream encoders are also based on 32-bit floats; discretizing back into the 10-bit space would be a low-hanging fruit for ~3.2x compression.
- Since we merely encode the remaining delta, we can go even more efficient by constructing something along the lines of a huffman tree. - Since we merely encode the remaining delta, we can go even more efficient by constructing something along the lines of a Huffman tree.
- Loss is not coming down during training... So basicaly nothing works right now. But the text I wrote is cool, right? - Loss is not coming down during training... So basically nothing works right now. But the text I wrote is cool, right?
- Make a logo - Make a logo
## Installation ## Installation

View File

@ -69,7 +69,7 @@ latent_projector:
#rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn'). #rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn').
middle_out: middle_out:
output_size: 4 # Size of the latent representation after message passing. region_latent_size: 4 # Size of the latent representation after message passing.
num_peers: 3 # Number of most correlated peers to consider. num_peers: 3 # Number of most correlated peers to consider.
predictor: predictor:
@ -95,7 +95,7 @@ latent_projector:
#rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn'). #rnn_num_layers: 1 # Number of layers for the RNN projector (if type is 'rnn').
middle_out: middle_out:
output_size: 8 # Size of the latent representation after message passing. region_latent_size: 8 # Size of the latent representation after message passing.
num_peers: 3 # Number of most correlated peers to consider. num_peers: 3 # Number of most correlated peers to consider.
predictor: predictor:
@ -121,7 +121,7 @@ latent_projector:
rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn'). rnn_num_layers: 2 # Number of layers for the RNN projector (if type is 'rnn').
middle_out: middle_out:
output_size: 4 # Size of the latent representation after message passing. region_latent_size: 4 # Size of the latent representation after message passing.
num_peers: 3 # Number of most correlated peers to consider. num_peers: 3 # Number of most correlated peers to consider.
predictor: predictor:

View File

@ -48,15 +48,15 @@ class SpikeRunner(Slate_Runner):
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc') latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
latent_size = slate.consume(config, 'latent_projector.latent_size') latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size') input_size = slate.consume(config, 'latent_projector.input_size')
output_size = slate.consume(config, 'middle_out.output_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
if latent_projector_type == 'fc': if latent_projector_type == 'fc':
self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
elif latent_projector_type == 'rnn': elif latent_projector_type == 'rnn':
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
self.middle_out = MiddleOut(latent_size=latent_size, output_size=output_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device) self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
self.predictor = Predictor(output_size=output_size, **slate.consume(config, 'predictor', expand=True)).to(device) self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
# Training parameters # Training parameters
self.input_size = input_size self.input_size = input_size

View File

@ -43,10 +43,10 @@ class LatentRNNProjector(nn.Module):
return latent return latent
class MiddleOut(nn.Module): class MiddleOut(nn.Module):
def __init__(self, latent_size, output_size, num_peers): def __init__(self, latent_size, region_latent_size, num_peers):
super(MiddleOut, self).__init__() super(MiddleOut, self).__init__()
self.num_peers = num_peers self.num_peers = num_peers
self.fc = nn.Linear(latent_size * 2 + 1, output_size) self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
def forward(self, my_latent, peer_latents, peer_metrics): def forward(self, my_latent, peer_latents, peer_metrics):
new_latents = [] new_latents = []
@ -61,10 +61,10 @@ class MiddleOut(nn.Module):
return averaged_latent return averaged_latent
class Predictor(nn.Module): class Predictor(nn.Module):
def __init__(self, output_size, layer_shapes, activations): def __init__(self, region_latent_size, layer_shapes, activations):
super(Predictor, self).__init__() super(Predictor, self).__init__()
layers = [] layers = []
in_features = output_size in_features = region_latent_size
for i, out_features in enumerate(layer_shapes): for i, out_features in enumerate(layer_shapes):
layers.append(nn.Linear(in_features, out_features)) layers.append(nn.Linear(in_features, out_features))
if activations[i] != 'None': if activations[i] != 'None':