diff --git a/models.py b/models.py index 9c2dc2f..fcf04ea 100644 --- a/models.py +++ b/models.py @@ -134,9 +134,10 @@ class MiddleOut(nn.Module): new_latents.append(new_latent) new_latents = torch.stack(new_latents) + averaged_latent = torch.mean(new_latents, dim=0) if self.residual: - return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2) - return torch.mean(new_latents, dim=0) + return my_latent - averaged_latent + return averaged_latent class Predictor(nn.Module): def __init__(self, region_latent_size, layer_shapes, activations):