Remove buggy normalization

This commit is contained in:
Dominik Moritz Roth 2024-05-29 21:12:32 +02:00
parent 102ddb8c85
commit 404b59c8ba

View File

@ -134,9 +134,10 @@ class MiddleOut(nn.Module):
new_latents.append(new_latent) new_latents.append(new_latent)
new_latents = torch.stack(new_latents) new_latents = torch.stack(new_latents)
averaged_latent = torch.mean(new_latents, dim=0)
if self.residual: if self.residual:
return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2) return my_latent - averaged_latent
return torch.mean(new_latents, dim=0) return averaged_latent
class Predictor(nn.Module): class Predictor(nn.Module):
def __init__(self, region_latent_size, layer_shapes, activations): def __init__(self, region_latent_size, layer_shapes, activations):