Better ResNets
This commit is contained in:
		
							parent
							
								
									90af40b317
								
							
						
					
					
						commit
						16ba578737
					
				| @ -99,7 +99,9 @@ class MiddleOut(nn.Module): | ||||
|             peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p] | ||||
|             combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1) | ||||
|             new_latent = self.fc(combined_input) | ||||
|             new_latents.append(new_latent * metric.unsqueeze(1)) | ||||
|             if self.residual: | ||||
|                 new_latent = new_latent * metric.unsqueeze(1) | ||||
|             new_latents.append(new_latent) | ||||
|          | ||||
|         new_latents = torch.stack(new_latents) | ||||
|         averaged_latent = torch.mean(new_latents, dim=0) | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user