Support for new BinomialHuffman
This commit is contained in:
		
							parent
							
								
									8576f5b741
								
							
						
					
					
						commit
						102ddb8c85
					
				
							
								
								
									
										56
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										56
									
								
								main.py
									
									
									
									
									
								
							@ -47,6 +47,7 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
        latent_size = slate.consume(config, 'latent_projector.latent_size')
 | 
			
		||||
        input_size = slate.consume(config, 'feature_extractor.input_size')
 | 
			
		||||
        region_latent_size = slate.consume(config, 'middle_out.region_latent_size')
 | 
			
		||||
        self.delta_shift = slate.consume(config, 'predictor.delta_shift', True)
 | 
			
		||||
        device = slate.consume(training_config, 'device', 'auto')
 | 
			
		||||
        if device == 'auto':
 | 
			
		||||
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 | 
			
		||||
@ -71,10 +72,10 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
        self.batch_size = slate.consume(training_config, 'batch_size')
 | 
			
		||||
        self.num_batches = slate.consume(training_config, 'num_batches')
 | 
			
		||||
        self.learning_rate = slate.consume(training_config, 'learning_rate')
 | 
			
		||||
        self.eval_freq = slate.consume(training_config, 'eval_freq')
 | 
			
		||||
        self.eval_freq = slate.consume(training_config, 'eval_freq', -1)
 | 
			
		||||
        self.save_path = slate.consume(training_config, 'save_path')
 | 
			
		||||
        self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0))
 | 
			
		||||
        self.value_scale = slate.consume(training_config, 'value_scale')
 | 
			
		||||
        self.value_scale = slate.consume(training_config, 'value_scale', 1.0)
 | 
			
		||||
 | 
			
		||||
        # Evaluation parameter
 | 
			
		||||
        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
 | 
			
		||||
@ -93,8 +94,7 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
            self.encoder = RiceEncoder()
 | 
			
		||||
        else:
 | 
			
		||||
            raise Exception('No such Encoder')
 | 
			
		||||
 | 
			
		||||
        self.encoder.build_model(self.all_data, **slate.consume(config, 'bitstream_encoding'))
 | 
			
		||||
        self.bitstream_encoder_config = slate.consume(config, 'bitstream_encoding')
 | 
			
		||||
 | 
			
		||||
        # Optimizer
 | 
			
		||||
        self.optimizer = torch.optim.Adam(list(self.projector.parameters()) + list(self.middle_out.parameters()) + list(self.predictor.parameters()), lr=self.learning_rate)
 | 
			
		||||
@ -154,11 +154,13 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
                        # Stack the segments to form the batch
 | 
			
		||||
                        stacked_segment = torch.stack([inputs] + peer_segments).to(device)
 | 
			
		||||
                        stacked_segments.append(stacked_segment)
 | 
			
		||||
                        target = lead_data[i + self.input_size + 1]
 | 
			
		||||
                        target = lead_data[i + self.input_size]
 | 
			
		||||
                        targets.append(target)
 | 
			
		||||
                        last = lead_data[i + self.input_size]
 | 
			
		||||
                        last = lead_data[i + self.input_size - 1]
 | 
			
		||||
                        lasts.append(last)
 | 
			
		||||
 | 
			
		||||
                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
 | 
			
		||||
 | 
			
		||||
                inp = torch.stack(stacked_segments) / self.value_scale
 | 
			
		||||
                feat = self.feat(inp)
 | 
			
		||||
                latents = self.projector(feat)
 | 
			
		||||
@ -178,12 +180,14 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
 | 
			
		||||
                prediction = self.predictor(region_latent)*self.value_scale
 | 
			
		||||
 | 
			
		||||
                if self.delta_shift:
 | 
			
		||||
                    prediction = prediction + las
 | 
			
		||||
 | 
			
		||||
                # Calculate loss and backpropagate
 | 
			
		||||
                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
 | 
			
		||||
                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
 | 
			
		||||
                loss = self.criterion(prediction, tar)
 | 
			
		||||
                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
 | 
			
		||||
                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
 | 
			
		||||
                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
 | 
			
		||||
                rel = err / np.sum(tar.cpu().detach().numpy())
 | 
			
		||||
                total_loss += loss.item()
 | 
			
		||||
                derrs.append(derr/np.prod(tar.size()).item())
 | 
			
		||||
@ -226,9 +230,7 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
        all_true = []
 | 
			
		||||
        all_predicted = []
 | 
			
		||||
        all_deltas = []
 | 
			
		||||
        compression_ratios = []
 | 
			
		||||
        exact_matches = 0
 | 
			
		||||
        total_sequences = 0
 | 
			
		||||
        all_steps = []
 | 
			
		||||
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            min_length = min([len(seq) for seq in self.test_data])
 | 
			
		||||
@ -261,11 +263,13 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
 | 
			
		||||
                    stacked_segment = torch.stack([inputs] + peer_segments).to(device)
 | 
			
		||||
                    stacked_segments.append(stacked_segment)
 | 
			
		||||
                    target = lead_data[i + self.input_size + 1]
 | 
			
		||||
                    target = lead_data[i + self.input_size]
 | 
			
		||||
                    targets.append(target)
 | 
			
		||||
                    last = lead_data[i + self.input_size]
 | 
			
		||||
                    last = lead_data[i + self.input_size - 1]
 | 
			
		||||
                    lasts.append(last)
 | 
			
		||||
 | 
			
		||||
                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).to(device)
 | 
			
		||||
 | 
			
		||||
                inp = torch.stack(stacked_segments) / self.value_scale
 | 
			
		||||
                feat = self.feat(inp)
 | 
			
		||||
                latents = self.projector(feat)
 | 
			
		||||
@ -276,11 +280,15 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
                region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
 | 
			
		||||
                prediction = self.predictor(region_latent) * self.value_scale
 | 
			
		||||
 | 
			
		||||
                if self.delta_shift:
 | 
			
		||||
                    prediction = prediction + las
 | 
			
		||||
 | 
			
		||||
                tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)
 | 
			
		||||
                las = torch.tensor(lasts, dtype=torch.float32).unsqueeze(-1).numpy()
 | 
			
		||||
                loss = self.criterion(prediction, tar)
 | 
			
		||||
                err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()))
 | 
			
		||||
                derr = np.sum(np.abs(las - tar.cpu().detach().numpy()))
 | 
			
		||||
                delta = prediction.cpu().detach().numpy() - tar.cpu().detach().numpy()
 | 
			
		||||
                err = np.sum(np.abs(delta))
 | 
			
		||||
                derr = np.sum(np.abs(las.cpu().detach().numpy() - tar.cpu().detach().numpy()))
 | 
			
		||||
                step = las.cpu().detach().numpy() - tar.cpu().detach().numpy()
 | 
			
		||||
                rel = err / np.sum(tar.cpu().detach().numpy())
 | 
			
		||||
                total_loss += loss.item()
 | 
			
		||||
                derrs.append(derr / np.prod(tar.size()).item())
 | 
			
		||||
@ -289,13 +297,15 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
 | 
			
		||||
                all_true.extend(tar.cpu().numpy())
 | 
			
		||||
                all_predicted.extend(prediction.cpu().numpy())
 | 
			
		||||
                all_deltas.extend((tar.cpu().numpy() - prediction.cpu().numpy()).tolist())
 | 
			
		||||
                all_deltas.extend(delta.tolist())
 | 
			
		||||
                all_steps.extend(step.tolist())
 | 
			
		||||
 | 
			
		||||
                if self.full_compression:
 | 
			
		||||
                    raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
 | 
			
		||||
                    comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
 | 
			
		||||
                    ratio = raw_l / comp_l
 | 
			
		||||
                    wandb.log({"eval/ratio": ratio}, step=epoch)
 | 
			
		||||
        if self.full_compression:
 | 
			
		||||
            self.encoder.build_model(delta_samples=delta, **self.bitstream_encoder_config)
 | 
			
		||||
            raw_l = len(refuckify(np.concatenate(all_true)).astype(np.int16))*16
 | 
			
		||||
            comp_l = len(self.encoder.encode(np.concatenate(all_deltas)))
 | 
			
		||||
            ratio = raw_l / comp_l
 | 
			
		||||
            wandb.log({"eval/ratio": ratio}, step=epoch)
 | 
			
		||||
 | 
			
		||||
        avg_loss = total_loss / len(self.test_data)
 | 
			
		||||
        tot_err = sum(errs) / len(errs)
 | 
			
		||||
@ -308,7 +318,7 @@ class SpikeRunner(Slate_Runner):
 | 
			
		||||
 | 
			
		||||
        # Visualize predictions
 | 
			
		||||
        #visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=1953, name='0.1s')
 | 
			
		||||
        img = visualize_prediction(all_true, all_predicted, all_deltas, epoch=epoch, num_points=195)
 | 
			
		||||
        img = visualize_prediction(all_true, all_predicted, all_deltas, all_steps, epoch=epoch, num_points=195)
 | 
			
		||||
        try:
 | 
			
		||||
            wandb.log({f"Prediction vs True Data 0.01s": wandb.Image(img)}, step=epoch)
 | 
			
		||||
        except:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user