Support for new BinomialHuffman

This commit is contained in:
Dominik Moritz Roth 2024-05-29 21:12:07 +02:00
parent 8576f5b741
commit 102ddb8c85

56
main.py
View File

@ -47,6 +47,7 @@ class SpikeRunner(Slate_Runner):
latent_size = slate.consume(config, 'latent_projector.latent_size') latent_size = slate.consume(config, 'latent_projector.latent_size')
input_size = slate.consume(config, 'feature_extractor.input_size') input_size = slate.consume(config, 'feature_extractor.input_size')
region_latent_size = slate.consume(config, 'middle_out.region_latent_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
device = slate.consume(training_config, 'device', 'auto') device = slate.consume(training_config, 'device', 'auto')
if device == 'auto': if device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@ -71,10 +72,10 @@ class SpikeRunner(Slate_Runner):
self.batch_size = slate.consume(training_config, 'batch_size') self.batch_size = slate.consume(training_config, 'batch_size')
self.num_batches = slate.consume(training_config, 'num_batches') self.num_batches = slate.consume(training_config, 'num_batches')
self.learning_rate = slate.consume(training_config, 'learning_rate') self.learning_rate = slate.consume(training_config, 'learning_rate')
self.eval_freq = slate.consume(training_config, 'eval_freq') self.eval_freq = slate.consume(training_config, 'eval_freq', -1)
self.save_path = slate.consume(training_config, 'save_path') self.save_path = slate.consume(training_config, 'save_path')
self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0)) self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
self.value_scale = slate.consume(training_config, 'value_scale') self.value_scale = slate.consume(training_config, 'value_scale', 1.0)
# Evaluation parameter # Evaluation parameter
self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False) self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
@ -93,8 +94,7 @@ class SpikeRunner(Slate_Runner):
self.encoder = RiceEncoder() self.encoder = RiceEncoder()
else: else:
raise Exception('No such Encoder') raise Exception('No such Encoder')
self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')
self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
# Optimizer # Optimizer
self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate) self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
@ -154,11 +154,13 @@ class SpikeRunner(Slate_Runner):
# Stack the segments to form the batch # Stack the segments to form the batch
stacked_segment = torch.stack([inputs] + peer_segments).to(device) stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment) stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1] target = lead_data[i + self.input_size]
targets.append(target) targets.append(target)
last = lead_data[i + self.input_size] last = lead_data[i + self.input_size - 1]
lasts.append(last) lasts.append(last)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
inp = torch.stack(stacked_segments) / self.value_scale inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp) feat = self.feat(inp)
latents = self.projector(feat) latents = self.projector(feat)
@ -178,12 +180,14 @@ class SpikeRunner(Slate_Runner):
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent)*self.value_scale prediction = self.predictor(region_latent)*self.value_scale
if self.delta_shift:
prediction = prediction + las
# Calculate loss and backpropagate # Calculate loss and backpropagate
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
loss = self.criterion(prediction, tar) loss = self.criterion(prediction, tar)
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
derr = np.sum(np.abs(las - tar.cpu().detach().numpy())) derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
rel = err / np.sum(tar.cpu().detach().numpy()) rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item() total_loss += loss.item()
derrs.append(derr/np.prod(tar.size()).item()) derrs.append(derr/np.prod(tar.size()).item())
@ -226,9 +230,7 @@ class SpikeRunner(Slate_Runner):
all_true = [] all_true = []
all_predicted = [] all_predicted = []
all_deltas = [] all_deltas = []
compression_ratios = [] all_steps = []
exact_matches = 0
total_sequences = 0
with torch.no_grad(): with torch.no_grad():
min_length = min([len(seq) for seq in self.test_data]) min_length = min([len(seq) for seq in self.test_data])
@ -261,11 +263,13 @@ class SpikeRunner(Slate_Runner):
stacked_segment = torch.stack([inputs] + peer_segments).to(device) stacked_segment = torch.stack([inputs] + peer_segments).to(device)
stacked_segments.append(stacked_segment) stacked_segments.append(stacked_segment)
target = lead_data[i + self.input_size + 1] target = lead_data[i + self.input_size]
targets.append(target) targets.append(target)
last = lead_data[i + self.input_size] last = lead_data[i + self.input_size - 1]
lasts.append(last) lasts.append(last)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
inp = torch.stack(stacked_segments) / self.value_scale inp = torch.stack(stacked_segments) / self.value_scale
feat = self.feat(inp) feat = self.feat(inp)
latents = self.projector(feat) latents = self.projector(feat)
@ -276,11 +280,15 @@ class SpikeRunner(Slate_Runner):
region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
prediction = self.predictor(region_latent) * self.value_scale prediction = self.predictor(region_latent) * self.value_scale
if self.delta_shift:
prediction = prediction + las
tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
loss = self.criterion(prediction, tar) loss = self.criterion(prediction, tar)
err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
derr = np.sum(np.abs(las - tar.cpu().detach().numpy())) err = np.sum(np.abs(delta))
derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
rel = err / np.sum(tar.cpu().detach().numpy()) rel = err / np.sum(tar.cpu().detach().numpy())
total_loss += loss.item() total_loss += loss.item()
derrs.append(derr / np.prod(tar.size()).item()) derrs.append(derr / np.prod(tar.size()).item())
@ -289,13 +297,15 @@ class SpikeRunner(Slate_Runner):
all_true.extend(tar.cpu().numpy()) all_true.extend(tar.cpu().numpy())
all_predicted.extend(prediction.cpu().numpy()) all_predicted.extend(prediction.cpu().numpy())
all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist()) all_deltas.extend(delta.tolist())
all_steps.extend(step.tolist())
if self.full_compression: if self.full_compression:
raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16 self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
comp_l = len(self.encoder.encode(np.concatenate(all_deltas))) raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
ratio = raw_l / comp_l comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
wandb.log({"eval/ratio": ratio}, step=epoch) ratio = raw_l / comp_l
wandb.log({"eval/ratio": ratio}, step=epoch)
avg_loss = total_loss / len(self.test_data) avg_loss = total_loss / len(self.test_data)
tot_err = sum(errs) / len(errs) tot_err = sum(errs) / len(errs)
@ -308,7 +318,7 @@ class SpikeRunner(Slate_Runner):
# Visualize predictions # Visualize predictions
#visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s') #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195) img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
try: try:
wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch) wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
except: except: