diff --git a/models.py b/models.py index fcf04ea..9c2dc2f 100644 --- a/models.py +++ b/models.py @@ -134,10 +134,9 @@ class MiddleOut(nn.Module): new_latents.append(new_latent) new_latents = torch.stack(new_latents) - averaged_latent = torch.mean(new_latents, dim=0) if self.residual: - return my_latent - averaged_latent - return averaged_latent + return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2) + return torch.mean(new_latents, dim=0) class Predictor(nn.Module): def __init__(self, region_latent_size, layer_shapes, activations):