323 lines
15 KiB
Python
323 lines
15 KiB
Python
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import random, math
|
|
from utils import visualize_prediction, plot_delta_distribution
|
|
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics
|
|
from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor
|
|
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder
|
|
import wandb
|
|
from pycallgraph2 import PyCallGraph
|
|
from pycallgraph2.output import GraphvizOutput
|
|
from slate import Slate, Slate_Runner
|
|
|
|
|
|
class SpikeRunner(Slate_Runner):
|
|
def setup(self, name):
|
|
print("Setup SpikeRunner")
|
|
|
|
self.name = name
|
|
slate, config = self.slate, self.config
|
|
|
|
training_config = slate.consume(config, 'training', expand=True)
|
|
data_config = slate.consume(config, 'data', expand=True)
|
|
|
|
data_url = slate.consume(data_config, 'url')
|
|
cut_length = slate.consume(data_config, 'cut_length', None)
|
|
download_and_extract_data(data_url)
|
|
all_data = load_all_wavs('data', cut_length)
|
|
|
|
split_ratio = slate.consume(data_config, 'split_ratio', 0.5)
|
|
self.train_data, self.test_data = split_data_by_time(all_data, split_ratio)
|
|
|
|
print("Reconstructing thread topology")
|
|
self.topology_matrix = compute_topology_metrics(self.train_data)
|
|
|
|
# Number of peers for message passing
|
|
self.num_peers = slate.consume(config, 'middle_out.num_peers')
|
|
|
|
# Precompute sorted indices for the top num_peers correlated leads
|
|
print("Precomputing sorted peer indices")
|
|
self.sorted_peer_indices = np.argsort(-self.topology_matrix, axis=1)[:, :self.num_peers]
|
|
|
|
# Model setup
|
|
print("Setting up models")
|
|
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
|
|
latent_size = slate.consume(config, 'latent_projector.latent_size')
|
|
input_size = slate.consume(config, 'latent_projector.input_size')
|
|
region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
|
|
device = slate.consume(training_config, 'device')
|
|
if device == 'auto':
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
self.device = device
|
|
|
|
if latent_projector_type == 'fc':
|
|
self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
|
elif latent_projector_type == 'rnn':
|
|
self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
|
elif latent_projector_type == 'fourier':
|
|
self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device)
|
|
|
|
self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device)
|
|
self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device)
|
|
|
|
# Training parameters
|
|
self.input_size = input_size
|
|
self.epochs = slate.consume(training_config, 'epochs')
|
|
self.batch_size = slate.consume(training_config, 'batch_size')
|
|
self.num_batches = slate.consume(training_config, 'num_batches')
|
|
self.learning_rate = slate.consume(training_config, 'learning_rate')
|
|
self.eval_freq = slate.consume(training_config, 'eval_freq')
|
|
self.save_path = slate.consume(training_config, 'save_path')
|
|
self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
|
|
self.value_scale = slate.consume(training_config, 'value_scale')
|
|
|
|
# Evaluation parameter
|
|
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
|
|
|
|
# Bitstream encoding
|
|
bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
|
|
if bitstream_type == 'identity':
|
|
self.encoder = IdentityEncoder()
|
|
elif bitstream_type == 'arithmetic':
|
|
self.encoder = ArithmeticEncoder()
|
|
elif bitstream_type == 'bzip2':
|
|
self.encoder = Bzip2Encoder()
|
|
|
|
# Optimizer
|
|
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
|
|
self.criterion = torch.nn.MSELoss()
|
|
print("SpikeRunner initialization complete")
|
|
|
|
def run(self, run, forceNoProfile=False):
|
|
if self.slate.consume(self.config, 'profiler.enable', False) and not forceNoProfile:
|
|
print('{PROFILER RUNNING}')
|
|
with PyCallGraph(output=GraphvizOutput(output_file=f'./profiler/{self.name}.png')):
|
|
self.run(run, forceNoProfile=True)
|
|
print('{PROFILER DONE}')
|
|
return
|
|
|
|
self.train_model()
|
|
|
|
def train_model(self):
|
|
device = self.device
|
|
min_length = min([len(seq) for seq in self.train_data])
|
|
|
|
best_test_score = float('inf')
|
|
|
|
for epoch in range(self.epochs):
|
|
total_loss = 0
|
|
errs = []
|
|
rels = []
|
|
for batch_num in range(self.num_batches):
|
|
|
|
# Create indices for training data and shuffle them
|
|
indices = list(range(len(self.train_data)))
|
|
random.shuffle(indices)
|
|
|
|
stacked_segments = []
|
|
peer_metrics = []
|
|
targets = []
|
|
|
|
for idx in indices[:self.batch_size]:
|
|
lead_data = self.train_data[idx][:min_length]
|
|
|
|
# Slide a window over the data with overlap
|
|
stride = max(1, self.input_size // 3) # Ensuring stride is at least 1
|
|
offset = np.random.randint(0, stride)
|
|
for i in range(offset, len(lead_data) - self.input_size-1-offset, stride):
|
|
lead_segment = lead_data[i:i + self.input_size]
|
|
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
|
|
|
|
# Collect the segments for the current lead and its peers
|
|
peer_segments = []
|
|
for peer_idx in self.sorted_peer_indices[idx]:
|
|
peer_segment = self.train_data[peer_idx][i:i + self.input_size]
|
|
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
|
|
peer_metric = torch.tensor([self.topology_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device)
|
|
peer_metrics.append(peer_metric)
|
|
|
|
# Stack the segments to form the batch
|
|
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
|
|
stacked_segments.append(stacked_segment)
|
|
target = lead_data[i + self.input_size + 1]
|
|
targets.append(target)
|
|
|
|
# Pass the batch through the projector
|
|
latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
|
|
|
|
my_latent = latents[:, 0, :]
|
|
peer_latents = latents[:, 1:, :]
|
|
|
|
# Scale gradients during backwards pass as configured
|
|
if self.peer_gradients_factor == 1.0:
|
|
pass
|
|
elif self.peer_gradients_factor == 0.0:
|
|
peer_latents = peer_latents.detach()
|
|
else:
|
|
peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor)
|
|
|
|
# Pass through MiddleOut
|
|
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
|
|
prediction = self.predictor(region_latent)*self.value_scale
|
|
|
|
# Calculate loss and backpropagate
|
|
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
|
|
loss = self.criterion(prediction, tar)
|
|
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
|
|
rel = err / np.sum(tar.cpu().detach().numpy())
|
|
total_loss += loss.item()
|
|
errs.append(err/np.prod(tar.size()).item())
|
|
rels.append(rel.item())
|
|
self.optimizer.zero_grad()
|
|
loss.backward()
|
|
self.optimizer.step()
|
|
|
|
tot_err = sum(errs)/len(errs)
|
|
approx_ratio = 1/(sum(rels)/len(rels))
|
|
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch)
|
|
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
|
|
|
|
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
|
|
print(f'Starting evaluation for epoch {epoch + 1}')
|
|
test_loss = self.evaluate_model(epoch)
|
|
if test_loss < best_test_score:
|
|
best_test_score = test_loss
|
|
self.save_models(epoch)
|
|
print(f'Evaluation complete for epoch {epoch + 1}')
|
|
|
|
|
|
wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch)
|
|
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
|
|
|
|
if (epoch + 1) % self.eval_freq == 0:
|
|
print(f'Starting evaluation for epoch {epoch + 1}')
|
|
test_loss = self.evaluate_model(epoch)
|
|
if test_loss < best_test_score:
|
|
best_test_score = test_loss
|
|
self.save_models(epoch)
|
|
print(f'Evaluation complete for epoch {epoch + 1}')
|
|
|
|
|
|
def evaluate_model(self, epoch):
|
|
print('Evaluating model...')
|
|
device = self.device
|
|
|
|
self.projector.eval()
|
|
self.middle_out.eval()
|
|
self.predictor.eval()
|
|
|
|
total_loss = 0
|
|
all_true = []
|
|
all_predicted = []
|
|
all_deltas = []
|
|
compression_ratios = []
|
|
exact_matches = 0
|
|
total_sequences = 0
|
|
|
|
with torch.no_grad():
|
|
for lead_idx in range(len(self.test_data[:8])):
|
|
lead_data = self.test_data[lead_idx]
|
|
true_data = []
|
|
predicted_data = []
|
|
delta_data = []
|
|
targets = []
|
|
|
|
min_length = min([len(seq) for seq in self.test_data])
|
|
|
|
# Initialize lists to store segments and peer metrics
|
|
stacked_segments = []
|
|
peer_metrics = []
|
|
|
|
for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8):
|
|
lead_segment = lead_data[i:i + self.input_size]
|
|
inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device)
|
|
|
|
# Collect peer segments and metrics
|
|
peer_segments = []
|
|
for peer_idx in self.sorted_peer_indices[lead_idx]:
|
|
peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length]
|
|
peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device))
|
|
peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device)
|
|
peer_metrics.append(peer_metric)
|
|
|
|
# Stack segments to form the batch
|
|
stacked_segment = torch.stack([inputs] + peer_segments).to(device)
|
|
stacked_segments.append(stacked_segment)
|
|
target = lead_data[i + self.input_size + 1]
|
|
targets.append(target)
|
|
|
|
# Pass the batch through the projector
|
|
latents = self.projector(torch.stack(stacked_segments))
|
|
|
|
my_latents = latents[:, 0, :]
|
|
peer_latents = latents[:, 1:, :]
|
|
|
|
# Pass through MiddleOut
|
|
new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics))
|
|
|
|
# Predict using the predictor
|
|
predictions = self.predictor(new_latents)
|
|
|
|
# Compute loss and store true and predicted data
|
|
for i, segment in enumerate(stacked_segments):
|
|
for t in range(self.input_size):
|
|
target = torch.tensor(targets[i])
|
|
true_data.append(target.cpu().numpy())
|
|
predicted_data.append(predictions[i].cpu().numpy())
|
|
delta_data.append((target - predictions[i]).cpu().numpy())
|
|
|
|
loss = self.criterion(predictions[i].cpu(), target)
|
|
total_loss += loss.item()
|
|
|
|
# Append true and predicted data for this lead sequence
|
|
all_true.append(true_data)
|
|
all_predicted.append(predicted_data)
|
|
all_deltas.append(delta_data)
|
|
|
|
if self.full_compression:
|
|
# Bitstream encoding
|
|
self.encoder.build_model(my_latents.cpu().numpy())
|
|
compressed_data = self.encoder.encode(my_latents.cpu().numpy())
|
|
decompressed_data = self.encoder.decode(compressed_data, len(my_latents))
|
|
compression_ratio = len(my_latents) / len(compressed_data)
|
|
compression_ratios.append(compression_ratio)
|
|
|
|
# Check if decompressed data matches the original data
|
|
if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5):
|
|
exact_matches += 1
|
|
total_sequences += 1
|
|
|
|
avg_loss = total_loss / len(self.test_data)
|
|
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
|
|
wandb.log({"evaluation_loss": avg_loss}, step=epoch)
|
|
|
|
# Visualize delta distribution
|
|
delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch)
|
|
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
|
|
|
|
if self.full_compression:
|
|
avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
|
|
exact_match_percentage = (exact_matches / total_sequences) * 100
|
|
print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
|
|
print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
|
|
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
|
|
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
|
|
|
|
print('Evaluation done for this epoch.')
|
|
return avg_loss
|
|
|
|
def save_models(self, epoch):
|
|
return
|
|
print('Saving models...')
|
|
torch.save(self.projector.state_dict(), os.path.join(self.save_path, f"best_projector_epoch_{epoch+1}.pt"))
|
|
torch.save(self.middle_out.state_dict(), os.path.join(self.save_path, f"best_middle_out_epoch_{epoch+1}.pt"))
|
|
torch.save(self.predictor.state_dict(), os.path.join(self.save_path, f"best_predictor_epoch_{epoch+1}.pt"))
|
|
print(f"New high score! Models saved at epoch {epoch+1}.")
|
|
|
|
if __name__ == '__main__':
|
|
print('Initializing...')
|
|
slate = Slate({'spikey': SpikeRunner})
|
|
slate.from_args()
|
|
print('Done.') |