A bunch of new things

This commit is contained in:
parent d35e3293fa
commit ef11acb1f6

main.py (71 changed lines)
@@ -4,9 +4,9 @@ import torch.nn as nn
 import numpy as np
 import random, math
 from utils import visualize_prediction, plot_delta_distribution
-from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all
+from data_processing import download_and_extract_data, load_all_wavs, split_data_by_time, compute_topology_metrics, unfuckify_all, refuckify
 from models import LatentFCProjector, LatentRNNProjector, MiddleOut, Predictor, FeatureExtractor
-from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder
+from bitstream import IdentityEncoder, ArithmeticEncoder, Bzip2Encoder, BinomialHuffmanEncoder, RiceEncoder
 import wandb
 from pycallgraph2 import PyCallGraph
 from pycallgraph2.output import GraphvizOutput
@@ -45,14 +45,14 @@ class SpikeRunner(Slate_Runner):
 print("Setting up models")
 latent_projector_type = slate.consume(config, 'latent_projector.type', default='fc')
 latent_size = slate.consume(config, 'latent_projector.latent_size')
-input_size = slate.consume(config, 'latent_projector.input_size')
+input_size = slate.consume(config, 'feature_extractor.input_size')
 region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
 device = slate.consume(training_config, 'device', 'auto')
 if device == 'auto':
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 self.device = device

-self.feat = FeatureExtractor(**slate.consume(config, 'feature_extractor', expand=True)).to(device)
+self.feat = FeatureExtractor(input_size=input_size, **slate.consume(config, 'feature_extractor', expand=True)).to(device)
 feature_size = self.feat.compute_output_size()

 if latent_projector_type == 'fc':
@@ -80,7 +80,7 @@ class SpikeRunner(Slate_Runner):
 self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)

 # Bitstream encoding
-bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='identity')
+bitstream_type = slate.consume(config, 'bitstream_encoding.type', default='rice')
 if bitstream_type == 'identity':
 self.encoder = IdentityEncoder()
 elif bitstream_type == 'arithmetic':
@@ -89,9 +89,13 @@ class SpikeRunner(Slate_Runner):
 self.encoder = Bzip2Encoder()
 elif bitstream_type == 'binomHuffman':
 self.encoder = BinomialHuffmanEncoder()
+elif bitstream_type == 'rice':
+self.encoder = RiceEncoder()
 else:
 raise Exception('No such Encoder')

+self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
+
 # Optimizer
 self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
 self.criterion = torch.nn.MSELoss()
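The default bitstream encoder switches to 'rice'. The actual RiceEncoder lives in bitstream.py and is not shown in this diff; as orientation only, a minimal Golomb-Rice coder over zig-zag-mapped residuals could look like the sketch below (all names and the parameter k are illustrative assumptions, not the repository's API).

# Hedged sketch: minimal Golomb-Rice coding of signed integer residuals.
def zigzag(x: int) -> int:
    """Map signed ints to unsigned: 0, -1, 1, -2, 2 -> 0, 1, 2, 3, 4."""
    return (x << 1) if x >= 0 else ((-x << 1) - 1)

def rice_encode(values, k: int = 4) -> str:
    """Each value becomes a unary quotient plus a k-bit remainder."""
    bits = []
    for v in values:
        n = zigzag(int(v))
        q, r = n >> k, n & ((1 << k) - 1)
        bits.append("1" * q + "0")           # unary part
        bits.append(format(r, f"0{k}b"))     # fixed-width remainder
    return "".join(bits)

def rice_decode(bitstring: str, count: int, k: int = 4):
    out, i = [], 0
    for _ in range(count):
        q = 0
        while bitstring[i] == "1":           # read unary quotient
            q, i = q + 1, i + 1
        i += 1                                # skip terminating 0
        r = int(bitstring[i:i + k], 2); i += k
        n = (q << k) | r
        out.append((n >> 1) if n % 2 == 0 else -((n + 1) >> 1))
    return out

# e.g. rice_encode([0, -1, 3], k=2) == "000" + "001" + "1010";
# small residuals (typical prediction deltas) get short codes.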
@@ -155,8 +159,9 @@ class SpikeRunner(Slate_Runner):
 last = lead_data[i + self.input_size]
 lasts.append(last)

-# Pass the batch through the projector
-latents = self.projector(torch.stack(stacked_segments)/self.value_scale)
+inp = torch.stack(stacked_segments) / self.value_scale
+feat = self.feat(inp)
+latents = self.projector(feat)

 my_latent = latents[:, 0, :]
 peer_latents = latents[:, 1:, :]
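The training forward pass now runs raw segments through the FeatureExtractor before the latent projector, instead of projecting scaled samples directly. A minimal runnable sketch of that data flow, with stand-in linear layers for the repo's modules and shapes inferred from how the loop slices latents:

# Hedged sketch of the reworked forward path; module internals are placeholders.
import torch
import torch.nn as nn

batch, n_leads, input_size, feature_size, latent_size = 8, 4, 195, 64, 32
value_scale = 1000.0

feature_extractor = nn.Linear(input_size, feature_size)   # stand-in for FeatureExtractor
projector = nn.Linear(feature_size, latent_size)          # stand-in for LatentFCProjector

stacked_segments = [torch.randn(n_leads, input_size) for _ in range(batch)]
inp = torch.stack(stacked_segments) / value_scale          # (batch, n_leads, input_size)
feat = feature_extractor(inp)                              # (batch, n_leads, feature_size)
latents = projector(feat)                                  # (batch, n_leads, latent_size)
my_latent, peer_latents = latents[:, 0, :], latents[:, 1:, :]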
@@ -192,7 +197,7 @@ class SpikeRunner(Slate_Runner):
 tot_derr = sum(derrs)/len(derrs)
 adv_delta = tot_derr / tot_err
 approx_ratio = 1/(sum(rels)/len(rels))
-wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
+wandb.log({"train/epoch": epoch, "train/loss": total_loss, "train/err": tot_err, "train/approx_ratio": approx_ratio, "train/adv_delta": adv_delta}, step=epoch)
 print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}')

 if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0:
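W&B groups metric keys by the prefix before "/" into separate panel sections, so the renamed keys land under "train" and "eval" chart groups in the dashboard. Minimal illustration (the project name is a placeholder, not taken from this repo):

import wandb

run = wandb.init(project="spikey", mode="offline")  # offline: logs locally, no account needed
run.log({"train/loss": 0.42, "eval/loss": 0.57}, step=1)
run.finish()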
@@ -232,11 +237,11 @@ class SpikeRunner(Slate_Runner):
 rels = []
 derrs = []

-for lead_idx in range(len(self.test_data)):
-lead_data = self.test_data[lead_idx][:min_length]
+indices = list(range(len(self.test_data)))
+random.shuffle(indices)

-indices = list(range(len(self.test_data)))
-random.shuffle(indices)
+for lead_idx in indices[:16]:
+lead_data = self.test_data[lead_idx][:min_length]

 stacked_segments = []
 peer_metrics = []
@@ -261,7 +266,9 @@ class SpikeRunner(Slate_Runner):
 last = lead_data[i + self.input_size]
 lasts.append(last)

-latents = self.projector(torch.stack(stacked_segments) / self.value_scale)
+inp = torch.stack(stacked_segments) / self.value_scale
+feat = self.feat(inp)
+latents = self.projector(feat)

 my_latent = latents[:, 0, :]
 peer_latents = latents[:, 1:, :]
@@ -285,8 +292,10 @@ class SpikeRunner(Slate_Runner):
 all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())

 if self.full_compression:
-raw = self.all_data
-comp = self.compress(raw)
+raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
+comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
+ratio = raw_l / comp_l
+wandb.log({"eval/ratio": ratio}, step=epoch)

 avg_loss = total_loss / len(self.test_data)
 tot_err = sum(errs) / len(errs)
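The logged ratio compares the raw int16 footprint of the reconstructed ground truth (16 bits per sample, hence the *16) against the length of the bitstream produced for the prediction deltas. A rough worked example, assuming len(encoder.encode(...)) returns a bit count:

# Hedged worked example of eval/ratio; the numbers are made up.
num_samples = 10_000
raw_l = num_samples * 16     # 160_000 bits of raw int16 signal
comp_l = 40_000              # hypothetical encoded size of the deltas, in bits
ratio = raw_l / comp_l       # -> 4.0, i.e. 4x compression
print(ratio)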
@@ -295,24 +304,31 @@ class SpikeRunner(Slate_Runner):
 approx_ratio = 1 / (sum(rels) / len(rels))

 print(f'Epoch {epoch+1}, Evaluation Loss: {avg_loss}')
-wandb.log({"evaluation_loss": avg_loss, "err": tot_err, "approx_ratio": approx_ratio, "adv_delta": adv_delta}, step=epoch)
+wandb.log({"eval/loss": avg_loss, "eval/err": tot_err, "eval/approx_ratio": approx_ratio, "eval/adv_delta": adv_delta}, step=epoch)

 # Visualize predictions
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195, name='0.01s')
-visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')
+#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
+img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
+try:
+wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
+except:
+pass
+#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=20, name='0.001s')

 # Plot delta distribution
 delta_plot_path = plot_delta_distribution(np.array(all_deltas), epoch)
-wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
+try:
+wandb.log({"delta_distribution": wandb.Image(delta_plot_path)}, step=epoch)
+except:
+pass

-if self.full_compression:
-avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
-exact_match_percentage = (exact_matches / total_sequences) * 100
-print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
-print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
-wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
-wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)
+#if self.full_compression:
+# avg_compression_ratio = sum(compression_ratios) / len(compression_ratios)
+# exact_match_percentage = (exact_matches / total_sequences) * 100
+# print(f'Epoch {epoch+1}, Average Compression Ratio: {avg_compression_ratio}')
+# print(f'Epoch {epoch+1}, Exact Match Percentage: {exact_match_percentage}%')
+# wandb.log({"average_compression_ratio": avg_compression_ratio}, step=epoch)
+# wandb.log({"exact_match_percentage": exact_match_percentage}, step=epoch)

 # Restore the original mode of the models
 if projector_mode:
@@ -344,6 +360,7 @@ class SpikeRunner(Slate_Runner):
 def compress(raw):
 threads = unfuckify_all(raw)
 for thread in threads:
+pass
 # 1. featExtr
 # 2. latentProj
 # 3. middleOut
@@ -70,7 +70,7 @@ class FeatureExtractor(nn.Module):
 size += length
 elif transform[0] == 'fourier':
 _, length = transform
-size += length * 2 # Fourier transform outputs both real and imaginary parts
+size += length * 2
 elif transform[0] == 'wavelet':
 _, wavelet_type, length = transform
 # Find the true size of the wavelet coefficients
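For the wavelet branch, "true size" means the total number of coefficients a decomposition of a length-`length` window produces, which depends on the wavelet and decomposition depth. One way to compute it, assuming the project uses PyWavelets (not confirmed by this diff; names follow the transform tuples above):

# Hedged sketch of sizing the wavelet output with pywt.
import numpy as np
import pywt

def wavelet_output_size(wavelet_type: str, length: int) -> int:
    # Decompose a dummy window and count every coefficient across all levels.
    coeffs = pywt.wavedec(np.zeros(length), wavelet_type)
    return sum(len(c) for c in coeffs)

# e.g. wavelet_output_size('db4', 195) gives the flattened coefficient count.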