Support for new BinomialHuffman
parent 8576f5b741
commit 102ddb8c85

1 changed file: main.py (56)
@@ -47,6 +47,7 @@ class SpikeRunner(Slate_Runner):
         latent_size = slate.consume(config, 'latent_projector.latent_size')
         input_size = slate.consume(config, 'feature_extractor.input_size')
         region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
+        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
         device = slate.consume(training_config, 'device', 'auto')
         if device == 'auto':
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
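Note: the new `predictor.delta_shift` key is read with a default of `True`, so configs that do not set it keep working. The three-argument `slate.consume(config, key, default)` pattern presumably means "look up the dotted key, fall back to the default if absent"; the helper below is only a hypothetical stand-in illustrating that reading, not Slate's actual implementation:

    # Hypothetical stand-in for the consume(config, key, default) call pattern;
    # Slate's real consume may behave differently (e.g. track consumed keys).
    def consume_default(config: dict, dotted_key: str, default=None):
        node = config
        *parents, leaf = dotted_key.split('.')
        for part in parents:
            node = node.get(part, {})        # walk nested config sections
        return node.pop(leaf, default)       # take the value, or the default

    cfg = {'predictor': {}}                  # key missing -> default applies
    delta_shift = consume_default(cfg, 'predictor.delta_shift', True)  # True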
@@ -71,10 +72,10 @@ class SpikeRunner(Slate_Runner):
         self.batch_size = slate.consume(training_config, 'batch_size')
         self.num_batches = slate.consume(training_config, 'num_batches')
         self.learning_rate = slate.consume(training_config, 'learning_rate')
-        self.eval_freq = slate.consume(training_config, 'eval_freq')
+        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)
         self.save_path = slate.consume(training_config, 'save_path')
         self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
-        self.value_scale = slate.consume(training_config, 'value_scale')
+        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)

         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
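Note: `eval_freq` and `value_scale` now fall back to `-1` and `1.0` when a config omits them, mirroring the existing `peer_gradients_factor` default. A minimal training config relying on those fallbacks (key names from this diff, values purely illustrative):

    # Illustrative training_config; omitted keys now fall back to defaults.
    training_config = {
        'batch_size': 32,
        'num_batches': 1000,
        'learning_rate': 1e-3,
        'save_path': 'models/spike_runner',
        # 'eval_freq' omitted   -> -1  (presumably: periodic eval disabled)
        # 'value_scale' omitted -> 1.0 (no extra scaling of the signal)
    }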
@@ -93,8 +94,7 @@ class SpikeRunner(Slate_Runner):
             self.encoder = RiceEncoder()
         else:
             raise Exception('No such Encoder')
-
-        self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
+        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')

         # Optimizer
         self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
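Note: the encoder is no longer fit on `self.all_data` at init; only the `bitstream_encoding` config section is stored, and `build_model` is called later during evaluation on prediction deltas (see the hunk around line 297 below). The BinomialHuffman encoder named in the commit title is not part of this hunk, so the sketch below is only a guess at the general idea the name suggests, a Huffman code whose table is derived from integer delta samples; all names are hypothetical and this is not the repository's implementation:

    import heapq
    from collections import Counter

    class ToyDeltaHuffman:
        """Hypothetical sketch, not the repo's BinomialHuffman: a Huffman code
        built from the empirical distribution of rounded delta samples. A
        'binomial' variant would presumably weight symbols by a fitted binomial
        pmf instead of raw counts. The degenerate single-symbol case is ignored."""

        def build_model(self, delta_samples, **kwargs):
            counts = Counter(int(round(d)) for d in delta_samples)
            # Heap entries: (weight, tie-break index, {symbol: codeword}).
            heap = [(n, i, {s: ''}) for i, (s, n) in enumerate(counts.items())]
            heapq.heapify(heap)
            tie = len(heap)
            while len(heap) > 1:
                w1, _, c1 = heapq.heappop(heap)
                w2, _, c2 = heapq.heappop(heap)
                merged = {s: '0' + c for s, c in c1.items()}
                merged.update({s: '1' + c for s, c in c2.items()})
                heapq.heappush(heap, (w1 + w2, tie, merged))
                tie += 1
            self.table = heap[0][2] if heap else {}

        def encode(self, deltas):
            # Returns a bitstring; its len() plays the role of comp_l below.
            return ''.join(self.table[int(round(d))] for d in deltas)

    enc = ToyDeltaHuffman()
    enc.build_model(delta_samples=[0, 0, 1, -1, 0, 2, 0, -1])
    bits = enc.encode([0, 1, -1, 0])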
@@ -154,11 +154,13 @@ class SpikeRunner(Slate_Runner):
                         # Stack the segments to form the batch
                         stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                         stacked_segments.append(stacked_segment)
-                        target = lead_data[i + self.input_size + 1]
+                        target = lead_data[i + self.input_size]
                         targets.append(target)
-                        last = lead_data[i + self.input_size]
+                        last = lead_data[i + self.input_size - 1]
                         lasts.append(last)

+                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
+
                 inp = torch.stack(stacked_segments) / self.value_scale
                 feat = self.feat(inp)
                 latents = self.projector(feat)
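Note: both here and in the matching evaluation hunk (around line 263) the indices shift by one: for a window starting at `i`, the target is now the first sample after the window, and `last` is the final sample inside it. A tiny standalone illustration of that windowing (toy data, hypothetical values):

    # Toy illustration of the corrected indexing; numbers are made up.
    lead_data = [10, 11, 13, 16, 20, 25]
    input_size = 3
    i = 1
    inputs = lead_data[i : i + input_size]   # [11, 13, 16] -> model input window
    target = lead_data[i + input_size]       # 20           -> next sample to predict
    last = lead_data[i + input_size - 1]     # 16           -> last observed sample
    assert last == inputs[-1]                # 'last' is the window's final element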
@@ -178,12 +180,14 @@ class SpikeRunner(Slate_Runner):
                 region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
                 prediction = self.predictor(region_latent)*self.value_scale

+                if self.delta_shift:
+                    prediction = prediction + las
+
                 # Calculate loss and backpropagate
                 tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
                 loss = self.criterion(prediction, tar)
                 err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
                 rel = err / np.sum(tar.cpu().detach().numpy())
                 total_loss += loss.item()
                 derrs.append(derr/np.prod(tar.size()).item())
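Note: with `delta_shift` enabled the predictor output is treated as a step relative to the last observed sample, and adding `las` turns it back into an absolute prediction before the loss; `derr` is then the error of the naive "repeat the last sample" baseline. A minimal numeric sketch of that residual formulation (toy tensors, not the repository's modules):

    import torch

    las = torch.tensor([[16.0], [20.0]])    # last sample of each window
    raw = torch.tensor([[3.8], [5.1]])      # hypothetical predictor output (steps)
    tar = torch.tensor([[20.0], [25.0]])    # true next samples

    prediction = raw + las                  # delta_shift: back to absolute values
    loss = torch.nn.functional.mse_loss(prediction, tar)
    baseline_err = (las - tar).abs().sum()  # "repeat last value" error (cf. derr)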
@@ -226,9 +230,7 @@ class SpikeRunner(Slate_Runner):
         all_true = []
         all_predicted = []
         all_deltas = []
-        compression_ratios = []
-        exact_matches = 0
-        total_sequences = 0
+        all_steps = []

         with torch.no_grad():
             min_length = min([len(seq) for seq in self.test_data])
@@ -261,11 +263,13 @@ class SpikeRunner(Slate_Runner):

                     stacked_segment = torch.stack([inputs] + peer_segments).to(device)
                     stacked_segments.append(stacked_segment)
-                    target = lead_data[i + self.input_size + 1]
+                    target = lead_data[i + self.input_size]
                     targets.append(target)
-                    last = lead_data[i + self.input_size]
+                    last = lead_data[i + self.input_size - 1]
                     lasts.append(last)

+                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
+
                 inp = torch.stack(stacked_segments) / self.value_scale
                 feat = self.feat(inp)
                 latents = self.projector(feat)
@@ -276,11 +280,15 @@ class SpikeRunner(Slate_Runner):
                 region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
                 prediction = self.predictor(region_latent) * self.value_scale

+                if self.delta_shift:
+                    prediction = prediction + las
+
                 tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
-                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
                 loss = self.criterion(prediction, tar)
-                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
-                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
+                delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
+                err = np.sum(np.abs(delta))
+                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
+                step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
                 rel = err / np.sum(tar.cpu().detach().numpy())
                 total_loss += loss.item()
                 derrs.append(derr / np.prod(tar.size()).item())
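Note: the evaluation loop now keeps the signed residual `delta = prediction - target` (accumulated into `all_deltas`, which the bitstream encoder later compresses), while `step = last - target` is the signed residual of the repeat-last-value baseline (accumulated into `all_steps`). A small numpy sketch of the four quantities on toy arrays:

    import numpy as np

    prediction = np.array([[19.4], [25.6]])  # toy model outputs
    tar = np.array([[20.0], [25.0]])         # toy targets
    las = np.array([[16.0], [20.0]])         # last observed samples

    delta = prediction - tar                 # signed model residuals (all_deltas)
    err = np.sum(np.abs(delta))              # model L1 error
    derr = np.sum(np.abs(las - tar))         # baseline L1 error (repeat last value)
    step = las - tar                         # signed baseline residuals (all_steps)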
@@ -289,13 +297,15 @@ class SpikeRunner(Slate_Runner):

                 all_true.extend(tar.cpu().numpy())
                 all_predicted.extend(prediction.cpu().numpy())
-                all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
+                all_deltas.extend(delta.tolist())
+                all_steps.extend(step.tolist())

-                if self.full_compression:
-                    raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
-                    comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
-                    ratio = raw_l / comp_l
-                    wandb.log({"eval/ratio": ratio}, step=epoch)
+        if self.full_compression:
+            self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
+            raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
+            comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
+            ratio = raw_l / comp_l
+            wandb.log({"eval/ratio": ratio}, step=epoch)

         avg_loss = total_loss / len(self.test_data)
         tot_err = sum(errs) / len(errs)
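Note: full-compression evaluation now runs once after the loop: the encoder model is fit via the stored `bitstream_encoding` config on the `delta` array of the final evaluation batch, the raw size is counted as 16 bits per int16 sample, and the raw-to-compressed ratio is logged. A hedged sketch of that ratio arithmetic with a stand-in encoder (the real `refuckify` transform and encoder classes are outside this diff and omitted here):

    import numpy as np

    def toy_encode(deltas):
        # Stand-in for self.encoder.encode: pretend each delta costs ~4 bits.
        return '0' * (4 * len(deltas))                 # fake bitstring

    all_true = [np.array([20.0]), np.array([25.0])]
    all_deltas = [np.array([-0.6]), np.array([0.6])]

    raw_l = len(np.concatenate(all_true).astype(np.int16)) * 16  # 16 bits/sample -> 32
    comp_l = len(toy_encode(np.concatenate(all_deltas)))         # bits in stream -> 8
    ratio = raw_l / comp_l                                       # 32 / 8 = 4.0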
@@ -308,7 +318,7 @@ class SpikeRunner(Slate_Runner):

         # Visualize predictions
         #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
-        img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
+        img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
         try:
             wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
         except: