diff --git a/main.py b/main.py index 00b425a..766ae98 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ import random, math from utils import visualize_prediction, plot_delta_distribution from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics from models import LatentFCProjector, LatentRNNProjector, LatentFourierProjector,MiddleOut, Predictor -from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder +from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder import wandb from pycallgraph2 import PyCallGraph from pycallgraph2.output import GraphvizOutput @@ -58,6 +58,8 @@ class SpikeRunner(Slate_Runner): self.projector = LatentRNNProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) elif latent_projector_type == 'fourier': self.projector = LatentFourierProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) + else: + raise Exception('No such Latent Projector') self.middle_out = MiddleOut(latent_size=latent_size, region_latent_size=region_latent_size, num_peers=self.num_peers, **slate.consume(config, 'middle_out', expand=True)).to(device) self.predictor = Predictor(region_latent_size=region_latent_size, **slate.consume(config, 'predictor', expand=True)).to(device) @@ -84,6 +86,10 @@ class SpikeRunner(Slate_Runner): self.encoder = ArithmeticEncoder() elif bitstream_type == 'bzip2': self.encoder = Bzip2Encoder() + elif bitstream_type == 'binomHuffman': + self.encoder = BinomialHuffmanEncoder() + else: + raise Exception('No such Encoder') # Optimizer self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) @@ -196,23 +202,16 @@ class SpikeRunner(Slate_Runner): self.save_models(epoch) print(f'Evaluation complete for epoch {epoch + 1}') - - wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch) - print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') - - if (epoch + 1) % self.eval_freq == 0: - print(f'Starting evaluation for epoch {epoch + 1}') - test_loss = self.evaluate_model(epoch) - if test_loss < best_test_score: - best_test_score = test_loss - self.save_models(epoch) - print(f'Evaluation complete for epoch {epoch + 1}') - - def evaluate_model(self, epoch): print('Evaluating model...') device = self.device + # Save the current mode of the models + projector_mode = self.projector.training + middle_out_mode = self.middle_out.training + predictor_mode = self.predictor.training + + # Set models to evaluation mode self.projector.eval() self.middle_out.eval() self.predictor.eval() @@ -226,84 +225,91 @@ class SpikeRunner(Slate_Runner): total_sequences = 0 with torch.no_grad(): - for lead_idx in range(len(self.test_data[:8])): - lead_data = self.test_data[lead_idx] - true_data = [] - predicted_data = [] - delta_data = [] - targets = [] + min_length = min([len(seq) for seq in self.test_data]) - min_length = min([len(seq) for seq in self.test_data]) + errs = [] + rels = [] + derrs = [] + + for lead_idx in range(len(self.test_data)): + lead_data = self.test_data[lead_idx][:min_length] + + indices = list(range(len(self.test_data))) + random.shuffle(indices) - # Initialize lists to store segments and peer metrics stacked_segments = [] peer_metrics = [] + targets = [] + lasts = [] - for i in range(0, len(lead_data) - self.input_size-1, self.input_size // 8): + for i in range(0, len(lead_data) - self.input_size - 1, self.input_size // 8): lead_segment = lead_data[i:i + self.input_size] inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device) - # Collect peer segments and metrics peer_segments = [] for peer_idx in self.sorted_peer_indices[lead_idx]: - peer_segment = self.test_data[peer_idx][i:i + self.input_size][:min_length] + peer_segment = self.test_data[peer_idx][:min_length][i:i + self.input_size] peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device)) peer_metric = torch.tensor([self.topology_matrix[lead_idx, peer_idx] for peer_idx in self.sorted_peer_indices[lead_idx]], dtype=torch.float32).to(device) peer_metrics.append(peer_metric) - # Stack segments to form the batch stacked_segment = torch.stack([inputs] + peer_segments).to(device) stacked_segments.append(stacked_segment) target = lead_data[i + self.input_size + 1] targets.append(target) + last = lead_data[i + self.input_size] + lasts.append(last) - # Pass the batch through the projector - latents = self.projector(torch.stack(stacked_segments)) + latents = self.projector(torch.stack(stacked_segments) / self.value_scale) - my_latents = latents[:, 0, :] + my_latent = latents[:, 0, :] peer_latents = latents[:, 1:, :] - # Pass through MiddleOut - new_latents = self.middle_out(my_latents, peer_latents, torch.stack(peer_metrics)) + region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) + prediction = self.predictor(region_latent) * self.value_scale - # Predict using the predictor - predictions = self.predictor(new_latents) + tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) + las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy() + loss = self.criterion(prediction, tar) + err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) + derr = np.sum(np.abs(las - tar.cpu().detach().numpy())) + rel = err / np.sum(tar.cpu().detach().numpy()) + total_loss += loss.item() + derrs.append(derr / np.prod(tar.size()).item()) + errs.append(err / np.prod(tar.size()).item()) + rels.append(rel.item()) - # Compute loss and store true and predicted data - for i, segment in enumerate(stacked_segments): - for t in range(self.input_size): - target = torch.tensor(targets[i]) - true_data.append(target.cpu().numpy()) - predicted_data.append(predictions[i].cpu().numpy()) - delta_data.append((target - predictions[i]).cpu().numpy()) - - loss = self.criterion(predictions[i].cpu(), target) - total_loss += loss.item() - - # Append true and predicted data for this lead sequence - all_true.append(true_data) - all_predicted.append(predicted_data) - all_deltas.append(delta_data) + all_true.extend(tar.cpu().numpy()) + all_predicted.extend(prediction.cpu().numpy()) + all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist()) if self.full_compression: - # Bitstream encoding - self.encoder.build_model(my_latents.cpu().numpy()) - compressed_data = self.encoder.encode(my_latents.cpu().numpy()) - decompressed_data = self.encoder.decode(compressed_data, len(my_latents)) - compression_ratio = len(my_latents) / len(compressed_data) + self.encoder.build_model(my_latent.cpu().numpy()) + compressed_data = self.encoder.encode(my_latent.cpu().numpy()) + decompressed_data = self.encoder.decode(compressed_data, len(my_latent)) + compression_ratio = len(my_latent) / len(compressed_data) compression_ratios.append(compression_ratio) - # Check if decompressed data matches the original data - if np.allclose(my_latents.cpu().numpy(), decompressed_data, atol=1e-5): + if np.allclose(my_latent.cpu().numpy(), decompressed_data, atol=1e-5): exact_matches += 1 total_sequences += 1 avg_loss = total_loss / len(self.test_data) - print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') - wandb.log({"evaluation_loss": avg_loss}, step=epoch) + tot_err = sum(errs) / len(errs) + tot_derr = sum(derrs) / len(derrs) + adv_delta = tot_derr / tot_err + approx_ratio = 1 / (sum(rels) / len(rels)) - # Visualize delta distribution - delta_plot_path = plot_delta_distribution(np.concatenate(all_deltas), epoch) + print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') + wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) + + # Visualize predictions + visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s') + visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s') + visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s') + + # Plot delta distribution + delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch) wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) if self.full_compression: @@ -314,6 +320,22 @@ class SpikeRunner(Slate_Runner): wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) + # Restore the original mode of the models + if projector_mode: + self.projector.train() + else: + self.projector.eval() + + if middle_out_mode: + self.middle_out.train() + else: + self.middle_out.eval() + + if predictor_mode: + self.predictor.train() + else: + self.predictor.eval() + print('Evaluation done for this epoch.') return avg_loss diff --git a/utils.py b/utils.py index 49cf63e..f8f7f93 100644 --- a/utils.py +++ b/utils.py @@ -14,39 +14,49 @@ def visualize_wav_data(sample_rate, data, title="WAV Data", num_points=None): plt.ylabel('Amplitude') plt.show() -def visualize_prediction(true_data, predicted_data, delta_data, sample_rate, num_points=None, epoch=None): - """Visualize the true data, predicted data, and deltas.""" +def visualize_prediction(true_data, predicted_data, delta_data, num_points=None, epoch=None, name=''): + """Visualize the true data, predicted data, deltas, and combined plot.""" if num_points: true_data = true_data[:num_points] predicted_data = predicted_data[:num_points] delta_data = delta_data[:num_points] - plt.figure(figsize=(15, 5)) + plt.figure(figsize=(20, 5)) - plt.subplot(3, 1, 1) + plt.subplot(2, 2, 1) plt.plot(true_data, label='True Data') plt.title('True Data') plt.xlabel('Sample') plt.ylabel('Amplitude') - plt.subplot(3, 1, 2) + plt.subplot(2, 2, 3) plt.plot(predicted_data, label='Predicted Data', color='orange') plt.title('Predicted Data') plt.xlabel('Sample') plt.ylabel('Amplitude') - plt.subplot(3, 1, 3) + plt.subplot(2, 2, 4) plt.plot(delta_data, label='Delta', color='red') plt.title('Delta') plt.xlabel('Sample') plt.ylabel('Amplitude') + plt.subplot(2, 2, 2) + plt.plot(true_data, label='True Data') + plt.plot(predicted_data, label='Predicted Data', color='orange') + plt.plot(delta_data, label='Delta', color='red') + plt.title('Combined Data') + plt.xlabel('Sample') + plt.ylabel('Amplitude') + plt.legend() + plt.tight_layout() tmp_dir = os.getenv('TMPDIR', '/tmp') file_path = os.path.join(tmp_dir, f'prediction_plot_{np.random.randint(1e6)}.png') plt.savefig(file_path) plt.close() - wandb.log({"Prediction vs True Data": wandb.Image(file_path)}, step=epoch) + wandb.log({f"Prediction vs True Data {name}": wandb.Image(file_path)}, step=epoch) + def plot_delta_distribution(deltas, epoch): """Plot the distribution of deltas."""