Implemented ResNet style for MiddleOut and configurable peer gradients
This commit is contained in:
		
							parent
							
								
									6076aaf36c
								
							
						
					
					
						commit
						5a3d491109
					
				
							
								
								
									
										3
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								main.py
									
									
									
									
									
								
							@ -68,6 +68,7 @@ class SpikeRunner(Slate_Runner):
 | 
				
			|||||||
        self.learning_rate = slate.consume(training_config, 'learning_rate')
 | 
					        self.learning_rate = slate.consume(training_config, 'learning_rate')
 | 
				
			||||||
        self.eval_freq = slate.consume(training_config, 'eval_freq')
 | 
					        self.eval_freq = slate.consume(training_config, 'eval_freq')
 | 
				
			||||||
        self.save_path = slate.consume(training_config, 'save_path')
 | 
					        self.save_path = slate.consume(training_config, 'save_path')
 | 
				
			||||||
 | 
					        self.peer_gradients = slate.consume(training_config, 'peer_gradients')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Evaluation parameter
 | 
					        # Evaluation parameter
 | 
				
			||||||
        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
 | 
					        self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
 | 
				
			||||||
@ -143,6 +144,8 @@ class SpikeRunner(Slate_Runner):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                my_latent = latents[:, 0, :]
 | 
					                my_latent = latents[:, 0, :]
 | 
				
			||||||
                peer_latents = latents[:, 1:, :]
 | 
					                peer_latents = latents[:, 1:, :]
 | 
				
			||||||
 | 
					                if not self.peer_gradients:
 | 
				
			||||||
 | 
					                    peer_latents = peer_latents.detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                # Pass through MiddleOut
 | 
					                # Pass through MiddleOut
 | 
				
			||||||
                new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
 | 
					                new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
 | 
				
			||||||
 | 
				
			|||||||
@ -80,10 +80,13 @@ class LatentFourierProjector(nn.Module):
 | 
				
			|||||||
        return latent
 | 
					        return latent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MiddleOut(nn.Module):
 | 
					class MiddleOut(nn.Module):
 | 
				
			||||||
    def __init__(self, latent_size, region_latent_size, num_peers):
 | 
					    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
 | 
				
			||||||
        super(MiddleOut, self).__init__()
 | 
					        super(MiddleOut, self).__init__()
 | 
				
			||||||
 | 
					        if residual:
 | 
				
			||||||
 | 
					            assert latent_size == region_latent_size
 | 
				
			||||||
        self.num_peers = num_peers
 | 
					        self.num_peers = num_peers
 | 
				
			||||||
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
 | 
					        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
 | 
				
			||||||
 | 
					        self.residual = residual
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, my_latent, peer_latents, peer_metrics):
 | 
					    def forward(self, my_latent, peer_latents, peer_metrics):
 | 
				
			||||||
        new_latents = []
 | 
					        new_latents = []
 | 
				
			||||||
@ -95,6 +98,8 @@ class MiddleOut(nn.Module):
 | 
				
			|||||||
        
 | 
					        
 | 
				
			||||||
        new_latents = torch.stack(new_latents)
 | 
					        new_latents = torch.stack(new_latents)
 | 
				
			||||||
        averaged_latent = torch.mean(new_latents, dim=0)
 | 
					        averaged_latent = torch.mean(new_latents, dim=0)
 | 
				
			||||||
 | 
					        if self.residual:
 | 
				
			||||||
 | 
					            return my_latent - averaged_latent
 | 
				
			||||||
        return averaged_latent
 | 
					        return averaged_latent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Predictor(nn.Module):
 | 
					class Predictor(nn.Module):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user