Remove buggy normalization
This commit is contained in:
		
							parent
							
								
									102ddb8c85
								
							
						
					
					
						commit
						404b59c8ba
					
				@ -134,9 +134,10 @@ class MiddleOut(nn.Module):
 | 
				
			|||||||
            new_latents.append(new_latent)
 | 
					            new_latents.append(new_latent)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        new_latents = torch.stack(new_latents)
 | 
					        new_latents = torch.stack(new_latents)
 | 
				
			||||||
 | 
					        averaged_latent = torch.mean(new_latents, dim=0)
 | 
				
			||||||
        if self.residual:
 | 
					        if self.residual:
 | 
				
			||||||
            return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2)
 | 
					            return my_latent - averaged_latent
 | 
				
			||||||
        return torch.mean(new_latents, dim=0)
 | 
					        return averaged_latent
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Predictor(nn.Module):
 | 
					class Predictor(nn.Module):
 | 
				
			||||||
    def __init__(self, region_latent_size, layer_shapes, activations):
 | 
					    def __init__(self, region_latent_size, layer_shapes, activations):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user