From ef11acb1f6b4afa83648161f16b2fa63fdc660b1 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 28 May 2024 12:53:33 +0200 Subject: [PATCH] A bunch of new things --- main.py | 71 ++++++++++++++++++++++++++++++++++--------------------- models.py | 2 +- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/main.py b/main.py index 3e3552b..308bfd0 100644 --- a/main.py +++ b/main.py @@ -4,9 +4,9 @@ import torch.nn as nn import numpy as np import random, math from utils import visualize_prediction, plot_delta_distribution -from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all +from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor -from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder +from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder import wandb from pycallgraph2 import PyCallGraph from pycallgraph2.output import GraphvizOutput @@ -45,14 +45,14 @@ class SpikeRunner(Slate_Runner): print("Setting up models") latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc') latent_size = slate.consume(config, 'latent_projector.latent_size') - input_size = slate.consume(config, 'latent_projector.input_size') + input_size = slate.consume(config, 'feature_extractor.input_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size') device = slate.consume(training_config, 'device', 'auto') if device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device - self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device) + self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device) feature_size = self.feat.compute_output_size() if latent_projector_type == 'fc': @@ -80,7 +80,7 @@ class SpikeRunner(Slate_Runner): self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False) # Bitstream encoding - bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity') + bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice') if bitstream_type == 'identity': self.encoder = IdentityEncoder() elif bitstream_type == 'arithmetic': @@ -89,9 +89,13 @@ class SpikeRunner(Slate_Runner): self.encoder = Bzip2Encoder() elif bitstream_type == 'binomHuffman': self.encoder = BinomialHuffmanEncoder() + elif bitstream_type == 'rice': + self.encoder = RiceEncoder() else: raise Exception('No such Encoder') + self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding')) + # Optimizer self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) self.criterion = torch.nn.MSELoss() @@ -155,8 +159,9 @@ class SpikeRunner(Slate_Runner): last = lead_data[i + self.input_size] lasts.append(last) - # Pass the batch through the projector - latents = self.projector(torch.stack(stacked_segments)/self.value_scale) + inp = torch.stack(stacked_segments) / self.value_scale + feat = self.feat(inp) + latents = self.projector(feat) my_latent = latents[:, 0, :] peer_latents = latents[:, 1:, :] @@ -192,7 +197,7 @@ class SpikeRunner(Slate_Runner): tot_derr = sum(derrs)/len(derrs) adv_delta = tot_derr / tot_err approx_ratio = 1/(sum(rels)/len(rels)) - wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) + wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch) print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: @@ -232,11 +237,11 @@ class SpikeRunner(Slate_Runner): rels = [] derrs = [] - for lead_idx in range(len(self.test_data)): - lead_data = self.test_data[lead_idx][:min_length] + indices = list(range(len(self.test_data))) + random.shuffle(indices) - indices = list(range(len(self.test_data))) - random.shuffle(indices) + for lead_idx in indices[:16]: + lead_data = self.test_data[lead_idx][:min_length] stacked_segments = [] peer_metrics = [] @@ -261,7 +266,9 @@ class SpikeRunner(Slate_Runner): last = lead_data[i + self.input_size] lasts.append(last) - latents = self.projector(torch.stack(stacked_segments) / self.value_scale) + inp = torch.stack(stacked_segments) / self.value_scale + feat = self.feat(inp) + latents = self.projector(feat) my_latent = latents[:, 0, :] peer_latents = latents[:, 1:, :] @@ -285,8 +292,10 @@ class SpikeRunner(Slate_Runner): all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist()) if self.full_compression: - raw = self.all_data - comp = self.compress(raw) + raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16 + comp_l = len(self.encoder.encode(np.concatenate(all_deltas))) + ratio = raw_l / comp_l + wandb.log({"eval/ratio": ratio}, step=epoch) avg_loss = total_loss / len(self.test_data) tot_err = sum(errs) / len(errs) @@ -295,24 +304,31 @@ class SpikeRunner(Slate_Runner): approx_ratio = 1 / (sum(rels) / len(rels)) print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') - wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) + wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch) # Visualize predictions - visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s') - visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s') - visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s') + #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s') + img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195) + try: + wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch) + except: + pass + #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s') # Plot delta distribution delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch) - wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) + try: + wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) + except: + pass - if self.full_compression: - avg_compression_ratio = sum(compression_ratios) / len(compression_ratios) - exact_match_percentage = (exact_matches / total_sequences) * 100 - print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}') - print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%') - wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) - wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) + #if self.full_compression: + # avg_compression_ratio = sum(compression_ratios) / len(compression_ratios) + # exact_match_percentage = (exact_matches / total_sequences) * 100 + # print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}') + # print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%') + # wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) + # wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) # Restore the original mode of the models if projector_mode: @@ -344,6 +360,7 @@ class SpikeRunner(Slate_Runner): def compress(raw): threads = unfuckify_all(raw) for thread in threads: + pass # 1. featExtr # 2. latentProj # 3. middleOut diff --git a/models.py b/models.py index d13303d..fcf04ea 100644 --- a/models.py +++ b/models.py @@ -70,7 +70,7 @@ class FeatureExtractor(nn.Module): size += length elif transform[0] == 'fourier': _, length = transform - size += length * 2 # Fourier transform outputs both real and imaginary parts + size += length * 2 elif transform[0] == 'wavelet': _, wavelet_type, length = transform # Find the true size of the wavelet coefficients