A bunch of new things

parent d35e3293fa
commit ef11acb1f6

main.py (71 lines changed)
@@ -4,9 +4,9 @@ import torch.nn as nn
 import numpy as np
 import random, math
 from utils import visualize_prediction, plot_delta_distribution
-from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
+from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
 from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
-from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
+from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
 import wandb
 from pycallgraph2 import PyCallGraph
 from pycallgraph2.output import GraphvizOutput
@@ -45,14 +45,14 @@ class SpikeRunner(Slate_Runner):
 print("Setting up models")
 latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
 latent_size = slate.consume(config, 'latent_projector.latent_size')
-input_size = slate.consume(config, 'latent_projector.input_size')
+input_size = slate.consume(config, 'feature_extractor.input_size')
 region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
 device = slate.consume(training_config, 'device', 'auto')
 if device == 'auto':
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 self.device = device

-self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
+self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
 feature_size = self.feat.compute_output_size()

 if latent_projector_type == 'fc':
@@ -80,7 +80,7 @@ class SpikeRunner(Slate_Runner):
 self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)

 # Bitstream encoding
-bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
+bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
 if bitstream_type == 'identity':
     self.encoder = IdentityEncoder()
 elif bitstream_type == 'arithmetic':
@@ -89,9 +89,13 @@ class SpikeRunner(Slate_Runner):
     self.encoder = Bzip2Encoder()
 elif bitstream_type == 'binomHuffman':
     self.encoder = BinomialHuffmanEncoder()
+elif bitstream_type == 'rice':
+    self.encoder = RiceEncoder()
 else:
     raise Exception('No such Encoder')
+
+self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))

 # Optimizer
 self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
 self.criterion = torch.nn.MSELoss()
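Context for the new default encoder: Rice codes suit the small-magnitude, geometric-ish residuals a decent predictor leaves behind, since tiny deltas cost only a few bits. Below is a minimal sketch of the technique itself, not the actual bitstream.RiceEncoder (whose internals and API this diff doesn't show):

    def rice_encode(values, k=4):
        """Toy Rice coder for signed ints; returns a list of 0/1 bits."""
        bits = []
        for v in values:
            u = 2 * v if v >= 0 else -2 * v - 1   # zigzag: fold the sign into the LSB
            q, r = u >> k, u & ((1 << k) - 1)
            bits += [1] * q + [0]                 # quotient in unary
            bits += [(r >> i) & 1 for i in range(k - 1, -1, -1)]  # remainder in k bits
        return bits

    # Small deltas stay cheap; cost grows linearly with magnitude:
    assert len(rice_encode([0], k=4)) == 5     # '0' terminator + 4 remainder bits
    assert len(rice_encode([-40], k=4)) == 9   # u=79 -> 4 unary bits + '0' + 4 bits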
@@ -155,8 +159,9 @@ class SpikeRunner(Slate_Runner):
 last = lead_data[i + self.input_size]
 lasts.append(last)

 # Pass the batch through the projector
-latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
+inp = torch.stack(stacked_segments) / self.value_scale
+feat = self.feat(inp)
+latents = self.projector(feat)

 my_latent = latents[:, 0, :]
 peer_latents = latents[:, 1:, :]
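This hunk puts feature extraction in front of the latent projector instead of projecting raw scaled samples directly. A toy stand-in showing just the new data flow (the module types and sizes here are made up for illustration, not the repo's real ones):

    import torch
    import torch.nn as nn

    feat = nn.Linear(32, 48)       # plays FeatureExtractor: raw window -> features
    projector = nn.Linear(48, 8)   # plays LatentFCProjector: features -> latents

    segments = torch.randn(4, 5, 32)      # (batch, leads, input_size)
    latents = projector(feat(segments))   # (4, 5, 8): features are computed first now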
@@ -192,7 +197,7 @@ class SpikeRunner(Slate_Runner):
 tot_derr = sum(derrs)/len(derrs)
 adv_delta = tot_derr / tot_err
 approx_ratio = 1/(sum(rels)/len(rels))
-wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
+wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
 print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

 if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
@@ -232,11 +237,11 @@ class SpikeRunner(Slate_Runner):
 rels = []
 derrs = []

-for lead_idx in range(len(self.test_data)):
-    lead_data = self.test_data[lead_idx][:min_length]
-indices = list(range(len(self.test_data)))
-random.shuffle(indices)
+indices = list(range(len(self.test_data)))
+random.shuffle(indices)
+for lead_idx in indices[:16]:
+    lead_data = self.test_data[lead_idx][:min_length]

 stacked_segments = []
 peer_metrics = []
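Side note on the new sampling: shuffling every index and slicing the first 16 draws a uniform sample without replacement, so each eval pass now covers 16 random leads rather than all of them. Assuming at least 16 test leads, this one-liner would be equivalent:

    lead_sample = random.sample(range(len(self.test_data)), k=16)  # same distribution as shuffle + [:16]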
@@ -261,7 +266,9 @@ class SpikeRunner(Slate_Runner):
 last = lead_data[i + self.input_size]
 lasts.append(last)

-latents = self.projector(torch.stack(stacked_segments) / self.value_scale)
+inp = torch.stack(stacked_segments) / self.value_scale
+feat = self.feat(inp)
+latents = self.projector(feat)

 my_latent = latents[:, 0, :]
 peer_latents = latents[:, 1:, :]
@@ -285,8 +292,10 @@ class SpikeRunner(Slate_Runner):
 all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())

 if self.full_compression:
-    raw = self.all_data
-    comp = self.compress(raw)
+    raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
+    comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
+    ratio = raw_l / comp_l
+    wandb.log({"eval/ratio": ratio}, step=epoch)

 avg_loss = total_loss / len(self.test_data)
 tot_err = sum(errs) / len(errs)
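On the new ratio bookkeeping: the raw side is counted as 16 bits per int16 sample of the re-assembled (refuckify'd) true signal, and the compressed side is the len() of the encoded delta stream, which assumes encoder.encode returns one element per output bit. Toy numbers to make the units concrete:

    n_samples = 10_000
    raw_l = n_samples * 16    # 160,000 bits of int16 input
    comp_l = 38_000           # e.g. deltas coded at ~3.8 bits/sample
    ratio = raw_l / comp_l    # ~4.2x compression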
@@ -295,24 +304,31 @@ class SpikeRunner(Slate_Runner):
 approx_ratio = 1 / (sum(rels) / len(rels))

 print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
-wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
+wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)

 # Visualize predictions
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
+#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
+img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
+try:
+    wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
+except:
+    pass
+#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')

 # Plot delta distribution
 delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
-wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
+try:
+    wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
+except:
+    pass

-if self.full_compression:
-    avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
-    exact_match_percentage = (exact_matches / total_sequences) * 100
-    print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
-    print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
-    wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
-    wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
+#if self.full_compression:
+#    avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
+#    exact_match_percentage = (exact_matches / total_sequences) * 100
+#    print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
+#    print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
+#    wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
+#    wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)

 # Restore the original mode of the models
 if projector_mode:
@@ -344,6 +360,7 @@ class SpikeRunner(Slate_Runner):
 def compress(raw):
     threads = unfuckify_all(raw)
     for thread in threads:
         pass
+        # 1. featExtr
+        # 2. latentProj
+        # 3. middleOut
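The numbered comments sketch where compress() is headed. A hedged guess at the eventual shape, using only names visible in this commit; every call signature below is an assumption, not the repo's actual API:

    def compress(self, raw):
        threads = unfuckify_all(raw)           # split the raw recording into per-lead threads
        streams = []
        for thread in threads:
            feat = self.feat(thread)           # 1. featExtr
            latent = self.projector(feat)      # 2. latentProj
            ctx = self.middle_out(latent)      # 3. middleOut: mix in peer-lead context
            pred = self.predictor(ctx)         # predict the next samples
            deltas = thread - pred             # residuals are what the bitstream coder sees
            streams.append(self.encoder.encode(deltas))
        return streams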
models.py

@@ -70,7 +70,7 @@ class FeatureExtractor(nn.Module):
     size += length
 elif transform[0] == 'fourier':
     _, length = transform
-    size += length * 2 # Fourier transform outputs both real and imaginary parts
+    size += length * 2
 elif transform[0] == 'wavelet':
     _, wavelet_type, length = transform
     # Find the true size of the wavelet coefficients
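On the fourier branch's size += length * 2: keeping `length` complex bins as separate real and imaginary channels doubles the feature count, which is what the removed inline comment was saying. A self-contained check (the flattening scheme is an assumption; the diff only shows the size arithmetic):

    import numpy as np

    length = 16
    spectrum = np.fft.fft(np.random.randn(256))[:length]      # keep the first 16 complex bins
    features = np.concatenate([spectrum.real, spectrum.imag]) # real/imag as separate channels
    assert features.size == length * 2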