A bunch of new things

This commit is contained in:
Dominik Moritz Roth 2024-05-28 12:53:33 +02:00
parent d35e3293fa
commit ef11acb1f6
2 changed files with 45 additions and 28 deletions

71
main.py
View File

@ -4,9 +4,9 @@ import torch.nn as nn
import numpy as np
import random, math
from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
import wandb
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
@ -45,14 +45,14 @@ class SpikeRunner(Slate_Runner):
print("Setting up models")
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size')
input_size = slate.consume(config, 'feature_extractor.input_size')
region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
device = slate.consume(training_config, 'device', 'auto')
if device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
feature_size = self.feat.compute_output_size()
if latent_projector_type == 'fc':
@ -80,7 +80,7 @@ class SpikeRunner(Slate_Runner):
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
# Bitstream encoding
bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
if bitstream_type == 'identity':
self.encoder = IdentityEncoder()
elif bitstream_type == 'arithmetic':
@ -89,9 +89,13 @@ class SpikeRunner(Slate_Runner):
self.encoder = Bzip2Encoder()
elif bitstream_type == 'binomHuffman':
self.encoder = BinomialHuffmanEncoder()
elif bitstream_type == 'rice':
self.encoder = RiceEncoder()
else:
raise Exception('No such Encoder')
self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
# Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
self.criterion = torch.nn.MSELoss()
@ -155,8 +159,9 @@ class SpikeRunner(Slate_Runner):
last = lead_data[i + self.input_size]
lasts.append(last)
# Pass the batch through the projector
latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp)
latents = self.projector(feat)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
@ -192,7 +197,7 @@ class SpikeRunner(Slate_Runner):
tot_derr = sum(derrs)/len(derrs)
adv_delta = tot_derr / tot_err
approx_ratio = 1/(sum(rels)/len(rels))
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
@ -232,11 +237,11 @@ class SpikeRunner(Slate_Runner):
rels = []
derrs = []
for lead_idx in range(len(self.test_data)):
lead_data = self.test_data[lead_idx][:min_length]
indices = list(range(len(self.test_data)))
random.shuffle(indices)
indices = list(range(len(self.test_data)))
random.shuffle(indices)
for lead_idx in indices[:16]:
lead_data = self.test_data[lead_idx][:min_length]
stacked_segments = []
peer_metrics = []
@ -261,7 +266,9 @@ class SpikeRunner(Slate_Runner):
last = lead_data[i + self.input_size]
lasts.append(last)
latents = self.projector(torch.stack(stacked_segments) / self.value_scale)
inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp)
latents = self.projector(feat)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]
@ -285,8 +292,10 @@ class SpikeRunner(Slate_Runner):
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
if self.full_compression:
raw = self.all_data
comp = self.compress(raw)
raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
ratio = raw_l / comp_l
wandb.log({"eval/ratio": ratio}, step=epoch)
avg_loss = total_loss / len(self.test_data)
tot_err = sum(errs) / len(errs)
@ -295,24 +304,31 @@ class SpikeRunner(Slate_Runner):
approx_ratio = 1 / (sum(rels) / len(rels))
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)
# Visualize predictions
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
try:
wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
except:
pass
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
# Plot delta distribution
delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
try:
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
except:
pass
if self.full_compression:
avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
exact_match_percentage = (exact_matches / total_sequences) * 100
print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
#if self.full_compression:
# avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
# exact_match_percentage = (exact_matches / total_sequences) * 100
# print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
# print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
# wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
# wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
# Restore the original mode of the models
if projector_mode:
@ -344,6 +360,7 @@ class SpikeRunner(Slate_Runner):
def compress(raw):
threads = unfuckify_all(raw)
for thread in threads:
pass
# 1. featExtr
# 2. latentProj
# 3. middleOut

View File

@ -70,7 +70,7 @@ class FeatureExtractor(nn.Module):
size += length
elif transform[0] == 'fourier':
_, length = transform
size += length * 2 # Fourier transform outputs both real and imaginary parts
size += length * 2
elif transform[0] == 'wavelet':
_, wavelet_type, length = transform
# Find the true size of the wavelet coefficients