# Spikey/main.py
# (386 lines, 18 KiB, Python)

import os
import torch
import torch.nn as nn
import numpy as np
import random, math
from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
import wandb
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
from slate import Slate, Slate_Runner
class SpikeRunner(Slate_Runner):
    """Slate runner that trains a spike-signal prediction pipeline and
    measures how well its residuals compress.

    Pipeline: FeatureExtractor -> latent projector (fc or rnn) -> MiddleOut
    (message passing over the most-correlated peer leads) -> Predictor.
    Prediction residuals are entropy-coded by a configurable bitstream
    encoder.
    """

    def setup(self, name: str) -> None:
        """Build data, models, encoder and optimizer from the Slate config.

        NOTE: ``slate.consume`` removes keys from ``config`` as it reads
        them, so the order of the consume calls below is significant.
        """
        print("Setup SpikeRunner")
        self.name = name
        slate, config = self.slate, self.config

        # Config sections.
        training_config = slate.consume(config, 'training', expand=True)
        data_config = slate.consume(config, 'data', expand=True)

        # Data acquisition and preprocessing.
        data_url = slate.consume(data_config, 'url')
        cut_length = slate.consume(data_config, 'cut_length', None)  # optional truncation of each recording
        download_and_extract_data(data_url)
        self.all_data = load_all_wavs('data', cut_length)

        # Split each lead chronologically into train/test portions.
        split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
        self.train_data, self.test_data = split_data_by_time(unfuckify_all(self.all_data), split_ratio)

        print("Reconstructing thread topology")
        # Pairwise correlation-style metric between leads; used to pick peers.
        self.topology_matrix = compute_topology_metrics(self.train_data)

        # Number of peers for message passing
        self.num_peers = slate.consume(config, 'middle_out.num_peers')

        # Precompute sorted indices for the top num_peers correlated leads
        print("Precomputing sorted peer indices")
        self.sorted_peer_indices = np.argsort(-self.topology_matrix, axis=1)[:, :self.num_peers]

        # Model setup
        print("Setting up models")
        latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
        latent_size = slate.consume(config, 'latent_projector.latent_size')
        input_size = slate.consume(config, 'feature_extractor.input_size')
        region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)  # predict delta vs. last sample

        device = slate.consume(training_config, 'device', 'auto')
        if device == 'auto':
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
        feature_size = self.feat.compute_output_size()
        if latent_projector_type == 'fc':
            self.projector = LatentFCProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        elif latent_projector_type == 'rnn':
            self.projector = LatentRNNProjector(latent_size=latent_size, feature_size=feature_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
        else:
            raise Exception('No such Latent Projector')
        self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
        self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)

        # Training parameters
        self.input_size = input_size
        self.epochs = slate.consume(training_config, 'epochs')
        self.batch_size = slate.consume(training_config, 'batch_size')
        self.num_batches = slate.consume(training_config, 'num_batches')
        self.learning_rate = slate.consume(training_config, 'learning_rate')
        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)  # -1 disables evaluation
        self.save_path = slate.consume(training_config, 'save_path')
        self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))  # scale of grads flowing into peer latents
        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)  # normalization factor for raw samples

        # Evaluation parameter
        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)

        # Bitstream encoding
        bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
        if bitstream_type == 'identity':
            self.encoder = IdentityEncoder()
        elif bitstream_type == 'arithmetic':
            self.encoder = ArithmeticEncoder()
        elif bitstream_type == 'bzip2':
            self.encoder = Bzip2Encoder()
        elif bitstream_type == 'binomHuffman':
            self.encoder = BinomialHuffmanEncoder()
        elif bitstream_type == 'rice':
            self.encoder = RiceEncoder()
        else:
            raise Exception('No such Encoder')
        # Remaining keys of this section are forwarded to encoder.build_model.
        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')

        # Optimizer over all three trainable modules jointly.
        self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
        self.criterion = torch.nn.MSELoss()
        print("SpikeRunner initialization complete")
def run(self, run, forceNoProfile=False):
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
print('{PROFILER RUNNING}')
with PyCallGraph(output=GraphvizOutput(output_file=f'./profiler/{self.name}.png')):
self.run(run, forceNoProfile=True)
print('{PROFILER DONE}')
return
self.train_model()
    def train_model(self) -> None:
        """Run the main training loop for ``self.epochs`` epochs.

        Each batch stacks overlapping windows from up to ``batch_size``
        randomly chosen leads together with the aligned windows of their
        top-correlated peers, and trains feature extractor, projector,
        MiddleOut and predictor end to end on next-sample prediction.
        Logs per-epoch metrics to wandb and triggers periodic evaluation.
        """
        device = self.device
        # Truncate all leads to the shortest one so window indices align.
        min_length = min([len(seq) for seq in self.train_data])
        best_test_score = float('inf')

        for epoch in range(self.epochs):
            total_loss = 0
            errs = []   # mean absolute prediction error per batch
            rels = []   # relative error (err / sum of targets) per batch
            derrs = []  # mean abs error of the trivial last-value predictor
            for batch_num in range(self.num_batches):
                # Create indices for training data and shuffle them
                indices = list(range(len(self.train_data)))
                random.shuffle(indices)

                stacked_segments = []
                peer_metrics = []
                targets = []
                lasts = []
                for idx in indices[:self.batch_size]:
                    lead_data = self.train_data[idx][:min_length]
                    # Slide a window over the data with overlap
                    stride = max(1, self.input_size // 3)  # Ensuring stride is at least 1
                    # Random offset so window positions vary between batches.
                    offset = np.random.randint(0, stride)
                    for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
                        lead_segment = lead_data[i:i + self.input_size]
                        inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)

                        # Collect the segments for the current lead and its peers
                        peer_segments = []
                        for peer_idx in self.sorted_peer_indices[idx]:
                            peer_segment = self.train_data[peer_idx][i:i + self.input_size]
                            peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
                        # Correlation strengths of the chosen peers (fed to MiddleOut).
                        peer_metric = torch.tensor([self.topology_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
                        peer_metrics.append(peer_metric)

                        # Stack the segments to form the batch
                        stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                        stacked_segments.append(stacked_segment)
                        # Next sample after the window is the prediction target.
                        target = lead_data[i + self.input_size]
                        targets.append(target)
                        # Last in-window sample: baseline and delta-shift anchor.
                        last = lead_data[i + self.input_size - 1]
                        lasts.append(last)

                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
                inp = torch.stack(stacked_segments) / self.value_scale
                feat = self.feat(inp)
                latents = self.projector(feat)
                # Index 0 is the lead itself; the rest are its peers.
                my_latent = latents[:, 0, :]
                peer_latents = latents[:, 1:, :]

                # Scale gradients during backwards pass as configured
                if self.peer_gradients_factor == 1.0:
                    pass
                elif self.peer_gradients_factor == 0.0:
                    peer_latents = peer_latents.detach()
                else:
                    peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)

                # Pass through MiddleOut
                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
                prediction = self.predictor(region_latent)*self.value_scale
                if self.delta_shift:
                    # Predictor outputs a delta relative to the last sample.
                    prediction = prediction + las

                # Calculate loss and backpropagate
                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
                loss = self.criterion(prediction, tar)
                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                rel = err / np.sum(tar.cpu().detach().numpy())
                total_loss += loss.item()
                derrs.append(derr/np.prod(tar.size()).item())
                errs.append(err/np.prod(tar.size()).item())
                rels.append(rel.item())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            tot_err = sum(errs)/len(errs)
            tot_derr = sum(derrs)/len(derrs)
            # adv_delta > 1 means the model beats the last-value baseline.
            adv_delta = tot_derr / tot_err
            approx_ratio = 1/(sum(rels)/len(rels))
            wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
            print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

            if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
                print(f'Starting evaluation for epoch {epoch + 1}')
                test_loss = self.evaluate_model(epoch)
                if test_loss < best_test_score:
                    best_test_score = test_loss
                    self.save_models(epoch)
                print(f'Evaluation complete for epoch {epoch + 1}')
def evaluate_model(self, epoch):
print('Evaluating model...')
device = self.device
# Save the current mode of the models
projector_mode = self.projector.training
middle_out_mode = self.middle_out.training
predictor_mode = self.predictor.training
# Set models to evaluation mode
self.projector.eval()
self.middle_out.eval()
self.predictor.eval()
total_loss = 0
all_true = []
all_predicted = []
all_deltas = []
all_steps = []
with torch.no_grad():
min_length = min([len(seq) for seq in self.test_data])
errs = []
rels = []
derrs = []
indices = list(range(len(self.test_data)))
random.shuffle(indices)
for lead_idx in indices[:16]:
lead_data = self.test_data[lead_idx][:min_length]
stacked_segments = []
peer_metrics = []
targets = []
lasts = []
for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8):
lead_segment = lead_data[i:i + self.input_size]
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
peer_segments = []
for peer_idx in self.sorted_peer_indices[lead_idx]:
peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size]
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
peer_metrics.append(peer_metric)
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size]
targets.append(target)
last = lead_data[i + self.input_size - 1]
lasts.append(last)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp)
latents = self.projector(feat)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent) * self.value_scale
if self.delta_shift:
prediction = prediction + las
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
loss = self.criterion(prediction, tar)
delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
err = np.sum(np.abs(delta))
derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item()
derrs.append(derr / np.prod(tar.size()).item())
errs.append(err / np.prod(tar.size()).item())
rels.append(rel.item())
all_true.extend(tar.cpu().numpy())
all_predicted.extend(prediction.cpu().numpy())
all_deltas.extend(delta.tolist())
all_steps.extend(step.tolist())
if self.full_compression:
self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
ratio = raw_l / comp_l
wandb.log({"eval/ratio": ratio}, step=epoch)
avg_loss = total_loss / len(self.test_data)
tot_err = sum(errs) / len(errs)
tot_derr = sum(derrs) / len(derrs)
adv_delta = tot_derr / tot_err
approx_ratio = 1 / (sum(rels) / len(rels))
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)
# Visualize predictions
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
try:
wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
except:
pass
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
# Plot delta distribution
delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
try:
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
except:
pass
#if self.full_compression:
# avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
# exact_match_percentage = (exact_matches / total_sequences) * 100
# print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
# print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
# wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
# wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
# Restore the original mode of the models
if projector_mode:
self.projector.train()
else:
self.projector.eval()
if middle_out_mode:
self.middle_out.train()
else:
self.middle_out.eval()
if predictor_mode:
self.predictor.train()
else:
self.predictor.eval()
print('Evaluation done for this epoch.')
return avg_loss
def save_models(self, epoch):
return
print('Saving models...')
torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
print(f"New high score! Models saved at epoch {epoch+1}.")
def compress(raw):
    """Compress raw recordings end to end (UNIMPLEMENTED stub).

    NOTE(review): only the thread-splitting step exists; the numbered
    comments below sketch the intended per-thread pipeline. The loop body
    is a placeholder and the function currently returns None.
    """
    threads = unfuckify_all(raw)
    for thread in threads:
        pass
        # 1. featExtr
        # 2. latentProj
        # 3. middleOut
        # 4. predictor
        # 5. calc delta
        # 6. encode
        # 7. return
if __name__ == '__main__':
    # Script entry point: register the runner and hand control to Slate,
    # which parses CLI arguments and drives SpikeRunner.setup/run.
    print('Initializing...')
    runner_registry = {'spikey': SpikeRunner}
    Slate(runner_registry).from_args()
    print('Done.')