Spikey/main.py

386 lines
18 KiB
Python
Raw Normal View History

2024-05-25 17:31:08 +02:00
import os
import torch
import torch.nn as nn
import numpy as np
2024-05-25 20:27:54 +02:00
import random, math
from utils import visualize_prediction, plot_delta_distribution
2024-05-28 12:53:33 +02:00
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
2024-05-27 17:00:02 +02:00
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
2024-05-28 12:53:33 +02:00
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
2024-05-25 17:31:08 +02:00
import wandb
2024-05-25 20:27:54 +02:00
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
from slate import Slate, Slate_Runner
2024-05-24 23:02:24 +02:00
2024-05-25 20:27:54 +02:00
class SpikeRunner(Slate_Runner):
    """Slate runner that trains and evaluates the Spikey compression stack.

    Pipeline: FeatureExtractor -> Latent projector -> MiddleOut (peer message
    passing over correlated leads) -> Predictor; prediction deltas are then
    bitstream-encoded.
    """

    def setup(self, name):
        """Load data, reconstruct lead topology and instantiate all models.

        NOTE: `slate.consume` pops keys out of the config dict, so the order
        of the consume calls below matters.
        """
        print("Setup SpikeRunner")
        self.name = name
        slate, config = self.slate, self.config

        # Config sections ('expand=True' consumes the whole sub-dict).
        training_config = slate.consume(config, 'training', expand=True)
        data_config = slate.consume(config, 'data', expand=True)
        data_url = slate.consume(data_config, 'url')
        cut_length = slate.consume(data_config, 'cut_length', None)

        # Data acquisition and train/test split (split runs along the time axis).
        download_and_extract_data(data_url)
        self.all_data = load_all_wavs('data', cut_length)
        split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
        self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)

        print("Reconstructing thread topology")
        self.topology_matrix = compute_topology_metrics(self.train_data)

        # Number of peers for message passing
        self.num_peers = slate.consume(config, 'middle_out.num_peers')
        # Precompute sorted indices for the top num_peers correlated leads
        # (argsort of the negated matrix -> descending correlation order).
        print("Precomputing sorted peer indices")
        self.sorted_peer_indices = np.argsort(-self.topology_matrix, axis=1)[:, :self.num_peers]

        # Model setup
        print("Setting up models")
        latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
        latent_size = slate.consume(config, 'latent_projector.latent_size')
        input_size = slate.consume(config, 'feature_extractor.input_size')
        region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
        # If enabled, the predictor output is added to the last observed sample
        # (i.e. the network predicts a delta rather than an absolute value).
        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
        device = slate.consume(training_config, 'device', 'auto')
        if device == 'auto':
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
        feature_size = self.feat.compute_output_size()
        if latent_projector_type == 'fc':
            self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        elif latent_projector_type == 'rnn':
            self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        else:
            raise Exception('No such Latent Projector')

        self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
        self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)

        # Training parameters
        self.input_size = input_size
        self.epochs = slate.consume(training_config, 'epochs')
        self.batch_size = slate.consume(training_config, 'batch_size')
        self.num_batches = slate.consume(training_config, 'num_batches')
        self.learning_rate = slate.consume(training_config, 'learning_rate')
        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)  # -1 disables periodic evaluation
        self.save_path = slate.consume(training_config, 'save_path')
        # 1.0 = full peer gradients, 0.0 = detach peers, otherwise scale grads.
        self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
        # Inputs are divided (and predictions multiplied) by this scale factor.
        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)

        # Evaluation parameter
        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)

        # Bitstream encoding
        bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
        if bitstream_type == 'identity':
            self.encoder = IdentityEncoder()
        elif bitstream_type == 'arithmetic':
            self.encoder = ArithmeticEncoder()
        elif bitstream_type == 'bzip2':
            self.encoder = Bzip2Encoder()
        elif bitstream_type == 'binomHuffman':
            self.encoder = BinomialHuffmanEncoder()
        elif bitstream_type == 'rice':
            self.encoder = RiceEncoder()
        else:
            raise Exception('No such Encoder')
        # Remaining (not yet consumed) bitstream options are forwarded to
        # encoder.build_model during evaluation.
        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')

        # Optimizer
        self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
        self.criterion = torch.nn.MSELoss()
        print("SpikeRunner initialization complete")
2024-05-24 22:01:59 +02:00
def run(self, run, forceNoProfile=False):
2024-05-24 23:02:24 +02:00
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
print('{PROFILER RUNNING}')
with PyCallGraph(output=GraphvizOutput(output_file=f'./profiler/{self.name}.png')):
self.run(run, forceNoProfile=True)
print('{PROFILER DONE}')
return
2024-05-25 17:31:08 +02:00
self.train_model()
    def train_model(self):
        """Main training loop.

        For each batch, samples windows from randomly chosen training leads
        together with the same windows of their most-correlated peer leads,
        predicts the sample immediately following each window, and optimizes
        projector + middle_out + predictor jointly with MSE loss. Periodically
        evaluates and saves the best-scoring models.
        """
        device = self.device
        # All leads are truncated to the shortest one so windows stay aligned.
        min_length = min([len(seq) for seq in self.train_data])

        best_test_score = float('inf')
        for epoch in range(self.epochs):
            total_loss = 0
            errs = []   # mean absolute prediction error per batch
            rels = []   # relative error per batch (err / sum of targets)
            derrs = []  # baseline error of predicting the last sample unchanged
            for batch_num in range(self.num_batches):
                # Create indices for training data and shuffle them
                indices = list(range(len(self.train_data)))
                random.shuffle(indices)
                stacked_segments = []
                peer_metrics = []
                targets = []
                lasts = []
                for idx in indices[:self.batch_size]:
                    lead_data = self.train_data[idx][:min_length]
                    # Slide a window over the data with overlap
                    stride = max(1, self.input_size // 3)  # Ensuring stride is at least 1
                    # Random per-lead offset decorrelates window positions across batches.
                    offset = np.random.randint(0, stride)
                    for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
                        lead_segment = lead_data[i:i + self.input_size]
                        inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
                        # Collect the segments for the current lead and its peers
                        peer_segments = []
                        for peer_idx in self.sorted_peer_indices[idx]:
                            peer_segment = self.train_data[peer_idx][i:i + self.input_size]
                            peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
                        peer_metric = torch.tensor([self.topology_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
                        peer_metrics.append(peer_metric)
                        # Stack the segments to form the batch
                        stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                        stacked_segments.append(stacked_segment)
                        # Target: the sample right after the window; 'last': the
                        # final sample inside the window (delta-shift baseline).
                        target = lead_data[i + self.input_size]
                        targets.append(target)
                        last = lead_data[i + self.input_size - 1]
                        lasts.append(last)

                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
                inp = torch.stack(stacked_segments) / self.value_scale
                feat = self.feat(inp)
                latents = self.projector(feat)
                # Index 0 is the lead itself; the rest are its peers.
                my_latent = latents[:, 0, :]
                peer_latents = latents[:, 1:, :]
                # Scale gradients during backwards pass as configured
                if self.peer_gradients_factor == 1.0:
                    pass
                elif self.peer_gradients_factor == 0.0:
                    peer_latents = peer_latents.detach()
                else:
                    peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)
                # Pass through MiddleOut
                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
                prediction = self.predictor(region_latent)*self.value_scale
                if self.delta_shift:
                    # Predictor output is interpreted as a delta on the last sample.
                    prediction = prediction + las
                # Calculate loss and backpropagate
                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
                loss = self.criterion(prediction, tar)
                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                # derr: error of the trivial "repeat last sample" predictor.
                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                # NOTE(review): denominator is a signed sum, not sum of
                # magnitudes — can be small or negative; confirm intended.
                rel = err / np.sum(tar.cpu().detach().numpy())
                total_loss += loss.item()
                derrs.append(derr/np.prod(tar.size()).item())
                errs.append(err/np.prod(tar.size()).item())
                rels.append(rel.item())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            tot_err = sum(errs)/len(errs)
            tot_derr = sum(derrs)/len(derrs)
            # adv_delta > 1 means the model beats the repeat-last baseline.
            adv_delta = tot_derr / tot_err
            approx_ratio = 1/(sum(rels)/len(rels))
            wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
            print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
            if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
                print(f'Starting evaluation for epoch {epoch + 1}')
                test_loss = self.evaluate_model(epoch)
                if test_loss < best_test_score:
                    best_test_score = test_loss
                    self.save_models(epoch)
                print(f'Evaluation complete for epoch {epoch + 1}')
2024-05-25 17:31:08 +02:00
def evaluate_model(self, epoch):
2024-05-25 20:27:54 +02:00
print('Evaluating model...')
2024-05-26 17:41:30 +02:00
device = self.device
2024-05-27 10:28:51 +02:00
# Save the current mode of the models
projector_mode = self.projector.training
middle_out_mode = self.middle_out.training
predictor_mode = self.predictor.training
# Set models to evaluation mode
2024-05-25 17:31:08 +02:00
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()
total_loss = 0
all_true = []
all_predicted = []
all_deltas = []
2024-05-29 21:12:07 +02:00
all_steps = []
2024-05-25 17:31:08 +02:00
with torch.no_grad():
2024-05-27 10:28:51 +02:00
min_length = min([len(seq) for seq in self.test_data])
errs = []
rels = []
derrs = []
2024-05-25 17:31:08 +02:00
2024-05-28 12:53:33 +02:00
indices = list(range(len(self.test_data)))
random.shuffle(indices)
2024-05-27 10:28:51 +02:00
2024-05-28 12:53:33 +02:00
for lead_idx in indices[:16]:
lead_data = self.test_data[lead_idx][:min_length]
2024-05-25 17:31:08 +02:00
2024-05-25 20:27:54 +02:00
stacked_segments = []
2024-05-26 13:48:30 +02:00
peer_metrics = []
2024-05-27 10:28:51 +02:00
targets = []
lasts = []
2024-05-25 20:27:54 +02:00
2024-05-27 10:28:51 +02:00
for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8):
2024-05-25 20:27:54 +02:00
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
peer_segments = []
for peer_idx in self.sorted_peer_indices[lead_idx]:
2024-05-27 10:28:51 +02:00
peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size]
2024-05-25 20:27:54 +02:00
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
2024-05-26 13:48:30 +02:00
peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
peer_metrics.append(peer_metric)
2024-05-25 17:31:08 +02:00
2024-05-25 20:27:54 +02:00
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment)
2024-05-29 21:12:07 +02:00
target = lead_data[i + self.input_size]
2024-05-25 20:27:54 +02:00
targets.append(target)
2024-05-29 21:12:07 +02:00
last = lead_data[i + self.input_size - 1]
2024-05-27 10:28:51 +02:00
lasts.append(last)
2024-05-25 17:31:08 +02:00
2024-05-29 21:12:07 +02:00
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
2024-05-28 12:53:33 +02:00
inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp)
latents = self.projector(feat)
2024-05-25 17:31:08 +02:00
2024-05-27 10:28:51 +02:00
my_latent = latents[:, 0, :]
2024-05-25 20:27:54 +02:00
peer_latents = latents[:, 1:, :]
2024-05-25 17:31:08 +02:00
2024-05-27 10:28:51 +02:00
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent) * self.value_scale
2024-05-25 20:27:54 +02:00
2024-05-29 21:12:07 +02:00
if self.delta_shift:
prediction = prediction + las
2024-05-27 10:28:51 +02:00
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
loss = self.criterion(prediction, tar)
2024-05-29 21:12:07 +02:00
delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
err = np.sum(np.abs(delta))
derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
2024-05-27 10:28:51 +02:00
rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item()
derrs.append(derr / np.prod(tar.size()).item())
errs.append(err / np.prod(tar.size()).item())
rels.append(rel.item())
2024-05-25 20:27:54 +02:00
2024-05-27 10:28:51 +02:00
all_true.extend(tar.cpu().numpy())
all_predicted.extend(prediction.cpu().numpy())
2024-05-29 21:12:07 +02:00
all_deltas.extend(delta.tolist())
all_steps.extend(step.tolist())
2024-05-25 17:31:08 +02:00
2024-05-29 21:12:07 +02:00
if self.full_compression:
self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
ratio = raw_l / comp_l
wandb.log({"eval/ratio": ratio}, step=epoch)
2024-05-25 17:31:08 +02:00
avg_loss = total_loss / len(self.test_data)
2024-05-27 10:28:51 +02:00
tot_err = sum(errs) / len(errs)
tot_derr = sum(derrs) / len(derrs)
adv_delta = tot_derr / tot_err
approx_ratio = 1 / (sum(rels) / len(rels))
2024-05-25 17:31:08 +02:00
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
2024-05-28 12:53:33 +02:00
wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)
2024-05-27 10:28:51 +02:00
# Visualize predictions
2024-05-28 12:53:33 +02:00
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
2024-05-29 21:12:07 +02:00
img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
2024-05-28 12:53:33 +02:00
try:
wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
except:
pass
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
2024-05-25 17:31:08 +02:00
2024-05-27 10:28:51 +02:00
# Plot delta distribution
delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
2024-05-28 12:53:33 +02:00
try:
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
except:
pass
#if self.full_compression:
# avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
# exact_match_percentage = (exact_matches / total_sequences) * 100
# print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
# print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
# wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
# wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
2024-05-25 17:31:08 +02:00
2024-05-27 10:28:51 +02:00
# Restore the original mode of the models
if projector_mode:
self.projector.train()
else:
self.projector.eval()
if middle_out_mode:
self.middle_out.train()
else:
self.middle_out.eval()
if predictor_mode:
self.predictor.train()
else:
self.predictor.eval()
2024-05-25 20:27:54 +02:00
print('Evaluation done for this epoch.')
2024-05-25 17:31:08 +02:00
return avg_loss
2024-05-24 23:02:24 +02:00
2024-05-25 17:31:08 +02:00
    def save_models(self, epoch):
        """Persist the current model weights to self.save_path.

        NOTE: saving is currently disabled via the early return below; all
        code after it is intentionally unreachable. Remove the return to
        re-enable checkpointing.
        """
        return
        print('Saving models...')
        torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
        torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
        torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
        print(f"New high score! Models saved at epoch {epoch+1}.")
2024-05-24 22:01:59 +02:00
2024-05-27 17:00:02 +02:00
def compress(raw):
    """Planned end-to-end compression entry point (work-in-progress stub).

    Splits the raw recording into threads and is intended to run each one
    through the full model pipeline; the loop body is not implemented yet.
    """
    threads = unfuckify_all(raw)
    for thread in threads:
        pass
        # 1. featExtr
        # 2. latentProj
        # 3. middleOut
        # 4. predictor
        # 5. calc delta
        # 6. encode
        # 7. return
2024-05-24 22:01:59 +02:00
if __name__ == '__main__':
    # CLI entry point: register SpikeRunner under the 'spikey' key and let
    # Slate parse the command-line arguments / config.
    print('Initializing...')
    runner_slate = Slate({'spikey': SpikeRunner})
    runner_slate.from_args()
    print('Done.')