From 95769cdeff813f83bc18fc55b71ef2ca4e3a1eb7 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Tue, 28 May 2024 16:09:02 +0200 Subject: [PATCH] Better magnitude normalization for ResNet operation mode in MiddleOut --- models.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/models.py b/models.py index fcf04ea..9c2dc2f 100644 --- a/models.py +++ b/models.py @@ -134,10 +134,9 @@ class MiddleOut(nn.Module): new_latents.append(new_latent) new_latents = torch.stack(new_latents) - averaged_latent = torch.mean(new_latents, dim=0) if self.residual: - return my_latent - averaged_latent - return averaged_latent + return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2) + return torch.mean(new_latents, dim=0) class Predictor(nn.Module): def __init__(self, region_latent_size, layer_shapes, activations):