From 102ddb8c8548de31be5e9e54b1d281eab5adbb2d Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 29 May 2024 21:12:07 +0200
Subject: [PATCH] Support for new BinomialHuffman

---
 main.py | 56 +++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 33 insertions(+), 23 deletions(-)

diff --git a/main.py b/main.py
index 308bfd0..98f567a 100644
--- a/main.py
+++ b/main.py
@@ -47,6 +47,7 @@ class SpikeRunner(Slate_Runner):
         latent_size = slate.consume(config, 'latent_projector.latent_size')
         input_size = slate.consume(config, 'feature_extractor.input_size')
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
+        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
         device = slate.consume(training_config, 'device', 'auto')
         if device == 'auto':
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -71,10 +72,10 @@ class SpikeRunner(Slate_Runner):
         self.batch_size = slate.consume(training_config, 'batch_size')
         self.num_batches = slate.consume(training_config, 'num_batches')
         self.learning_rate = slate.consume(training_config, 'learning_rate')
-        self.eval_freq = slate.consume(training_config, 'eval_freq')
+        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)
         self.save_path = slate.consume(training_config, 'save_path')
         self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
-        self.value_scale = slate.consume(training_config, 'value_scale')
+        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)
 
         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
@@ -93,8 +94,7 @@ class SpikeRunner(Slate_Runner):
             self.encoder = RiceEncoder()
         else:
             raise Exception('No such Encoder')
-
-        self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
+        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')
 
         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
@@ -154,11 +154,13 @@ class SpikeRunner(Slate_Runner):
                 # Stack the segments to form the batch
                 stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                 stacked_segments.append(stacked_segment)
-                target = lead_data[i + self.input_size + 1]
+                target = lead_data[i + self.input_size]
                 targets.append(target)
-                last = lead_data[i + self.input_size]
+                last = lead_data[i + self.input_size - 1]
                 lasts.append(last)
 
+            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
+
             inp = torch.stack(stacked_segments) / self.value_scale
             feat = self.feat(inp)
             latents = self.projector(feat)
@@ -178,12 +180,14 @@ class SpikeRunner(Slate_Runner):
 
             region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
             prediction = self.predictor(region_latent)*self.value_scale
+            if self.delta_shift:
+                prediction = prediction + las
+
             # Calculate loss and backpropagate
             tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
             loss = self.criterion(prediction, tar)
             err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-            derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+            derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
             rel = err / np.sum(tar.cpu().detach().numpy())
             total_loss += loss.item()
             derrs.append(derr/np.prod(tar.size()).item())
@@ -226,9 +230,7 @@ class SpikeRunner(Slate_Runner):
         all_true = []
         all_predicted = []
         all_deltas = []
-        compression_ratios = []
-        exact_matches = 0
-        total_sequences = 0
+        all_steps = []
 
         with torch.no_grad():
             min_length = min([len(seq) for seq in self.test_data])
@@ -261,11 +263,13 @@ class SpikeRunner(Slate_Runner):
 
                 stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                 stacked_segments.append(stacked_segment)
-                target = lead_data[i + self.input_size + 1]
+                target = lead_data[i + self.input_size]
                 targets.append(target)
-                last = lead_data[i + self.input_size]
+                last = lead_data[i + self.input_size - 1]
                 lasts.append(last)
 
+            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
+
             inp = torch.stack(stacked_segments) / self.value_scale
             feat = self.feat(inp)
             latents = self.projector(feat)
@@ -276,11 +280,15 @@ class SpikeRunner(Slate_Runner):
 
             region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
             prediction = self.predictor(region_latent) * self.value_scale
+            if self.delta_shift:
+                prediction = prediction + las
+
             tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-            las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
             loss = self.criterion(prediction, tar)
-            err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-            derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+            delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
+            err = np.sum(np.abs(delta))
+            derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+            step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
             rel = err / np.sum(tar.cpu().detach().numpy())
             total_loss += loss.item()
             derrs.append(derr / np.prod(tar.size()).item())
@@ -289,13 +297,15 @@ class SpikeRunner(Slate_Runner):
 
             all_true.extend(tar.cpu().numpy())
             all_predicted.extend(prediction.cpu().numpy())
-            all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
+            all_deltas.extend(delta.tolist())
+            all_steps.extend(step.tolist())
 
-            if self.full_compression:
-                raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
-                comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
-                ratio = raw_l / comp_l
-                wandb.log({"eval/ratio": ratio}, step=epoch)
+        if self.full_compression:
+            self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
+            raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
+            comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
+            ratio = raw_l / comp_l
+            wandb.log({"eval/ratio": ratio}, step=epoch)
 
         avg_loss = total_loss / len(self.test_data)
         tot_err = sum(errs) / len(errs)
@@ -308,7 +318,7 @@ class SpikeRunner(Slate_Runner):
 
         # Visualize predictions
         #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
-        img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
+        img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
         try:
             wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
         except: