A bunch of new things

This commit is contained in:
Dominik Moritz Roth 2024-05-28 12:53:33 +02:00
parent d35e3293fa
commit ef11acb1f6
2 changed files with 45 additions and 28 deletions

67
main.py
View File

@ -4,9 +4,9 @@ import torch.nn as nn
import numpy as np import numpy as np
import random, math import random, math
from utils import visualize_prediction, plot_delta_distribution from utils import visualize_prediction, plot_delta_distribution
from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
import wandb import wandb
from pycallgraph2 import PyCallGraph from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput from pycallgraph2.output import GraphvizOutput
@ -45,14 +45,14 @@ class SpikeRunner(Slate_Runner):
print("Setting up models") print("Setting up models")
latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc') latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
latent_size = slate.consume(config, 'latent_projector.latent_size') latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'latent_projector.input_size') input_size = slate.consume(config, 'feature_extractor.input_size')
region_latent_size = slate.consume(config, 'middle_out.region_latent_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
device = slate.consume(training_config, 'device', 'auto') device = slate.consume(training_config, 'device', 'auto')
if device == 'auto': if device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device self.device = device
self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device) self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
feature_size = self.feat.compute_output_size() feature_size = self.feat.compute_output_size()
if latent_projector_type == 'fc': if latent_projector_type == 'fc':
@ -80,7 +80,7 @@ class SpikeRunner(Slate_Runner):
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False) self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
# Bitstream encoding # Bitstream encoding
bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity') bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
if bitstream_type == 'identity': if bitstream_type == 'identity':
self.encoder = IdentityEncoder() self.encoder = IdentityEncoder()
elif bitstream_type == 'arithmetic': elif bitstream_type == 'arithmetic':
@ -89,9 +89,13 @@ class SpikeRunner(Slate_Runner):
self.encoder = Bzip2Encoder() self.encoder = Bzip2Encoder()
elif bitstream_type == 'binomHuffman': elif bitstream_type == 'binomHuffman':
self.encoder = BinomialHuffmanEncoder() self.encoder = BinomialHuffmanEncoder()
elif bitstream_type == 'rice':
self.encoder = RiceEncoder()
else: else:
raise Exception('No such Encoder') raise Exception('No such Encoder')
self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
# Optimizer # Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
self.criterion = torch.nn.MSELoss() self.criterion = torch.nn.MSELoss()
@ -155,8 +159,9 @@ class SpikeRunner(Slate_Runner):
last = lead_data[i + self.input_size] last = lead_data[i + self.input_size]
lasts.append(last) lasts.append(last)
# Pass the batch through the projector inp = torch.stack(stacked_segments) / self.value_scale
latents = self.projector(torch.stack(stacked_segments)/self.value_scale) feat = self.feat(inp)
latents = self.projector(feat)
my_latent = latents[:, 0, :] my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :] peer_latents = latents[:, 1:, :]
@ -192,7 +197,7 @@ class SpikeRunner(Slate_Runner):
tot_derr = sum(derrs)/len(derrs) tot_derr = sum(derrs)/len(derrs)
adv_delta = tot_derr / tot_err adv_delta = tot_derr / tot_err
approx_ratio = 1/(sum(rels)/len(rels)) approx_ratio = 1/(sum(rels)/len(rels))
wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')
if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
@ -232,12 +237,12 @@ class SpikeRunner(Slate_Runner):
rels = [] rels = []
derrs = [] derrs = []
for lead_idx in range(len(self.test_data)):
lead_data = self.test_data[lead_idx][:min_length]
indices = list(range(len(self.test_data))) indices = list(range(len(self.test_data)))
random.shuffle(indices) random.shuffle(indices)
for lead_idx in indices[:16]:
lead_data = self.test_data[lead_idx][:min_length]
stacked_segments = [] stacked_segments = []
peer_metrics = [] peer_metrics = []
targets = [] targets = []
@ -261,7 +266,9 @@ class SpikeRunner(Slate_Runner):
last = lead_data[i + self.input_size] last = lead_data[i + self.input_size]
lasts.append(last) lasts.append(last)
latents = self.projector(torch.stack(stacked_segments) / self.value_scale) inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp)
latents = self.projector(feat)
my_latent = latents[:, 0, :] my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :] peer_latents = latents[:, 1:, :]
@ -285,8 +292,10 @@ class SpikeRunner(Slate_Runner):
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist()) all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
if self.full_compression: if self.full_compression:
raw = self.all_data raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
comp = self.compress(raw) comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
ratio = raw_l / comp_l
wandb.log({"eval/ratio": ratio}, step=epoch)
avg_loss = total_loss / len(self.test_data) avg_loss = total_loss / len(self.test_data)
tot_err = sum(errs) / len(errs) tot_err = sum(errs) / len(errs)
@ -295,24 +304,31 @@ class SpikeRunner(Slate_Runner):
approx_ratio = 1 / (sum(rels) / len(rels)) approx_ratio = 1 / (sum(rels) / len(rels))
print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}') print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch) wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)
# Visualize predictions # Visualize predictions
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s') #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s') img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s') try:
wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
except:
pass
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
# Plot delta distribution # Plot delta distribution
delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch) delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
try:
wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch) wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
except:
pass
if self.full_compression: #if self.full_compression:
avg_compression_ratio = sum(compression_ratios) / len(compression_ratios) # avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
exact_match_percentage = (exact_matches / total_sequences) * 100 # exact_match_percentage = (exact_matches / total_sequences) * 100
print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}') # print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%') # print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch) # wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch) # wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
# Restore the original mode of the models # Restore the original mode of the models
if projector_mode: if projector_mode:
@ -344,6 +360,7 @@ class SpikeRunner(Slate_Runner):
def compress(raw): def compress(raw):
threads = unfuckify_all(raw) threads = unfuckify_all(raw)
for thread in threads: for thread in threads:
pass
# 1. featExtr # 1. featExtr
# 2. latentProj # 2. latentProj
# 3. middleOut # 3. middleOut

View File

@ -70,7 +70,7 @@ class FeatureExtractor(nn.Module):
size += length size += length
elif transform[0] == 'fourier': elif transform[0] == 'fourier':
_, length = transform _, length = transform
size += length * 2 # Fourier transform outputs both real and imaginary parts size += length * 2
elif transform[0] == 'wavelet': elif transform[0] == 'wavelet':
_, wavelet_type, length = transform _, wavelet_type, length = transform
# Find the true size of the wavelet coefficients # Find the true size of the wavelet coefficients