From 404b59c8ba2d57c1f4e025d8e16bf4f8f8ddbb0b Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 29 May 2024 21:12:32 +0200 Subject: [PATCH] Remove buggy normalization --- models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 9c2dc2f..fcf04ea 100644 --- a/models.py +++ b/models.py @@ -134,9 +134,10 @@ class MiddleOut(nn.Module): new_latents.append(new_latent) new_latents = torch.stack(new_latents) + averaged_latent = torch.mean(new_latents, dim=0) if self.residual: - return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2) - return torch.mean(new_latents, dim=0) + return my_latent - averaged_latent + return averaged_latent class Predictor(nn.Module): def __init__(self, region_latent_size, layer_shapes, activations):