Better magnitude normalization for ResNet operation mode in MiddleOut

This commit is contained in:
Dominik Moritz Roth 2024-05-28 16:09:02 +02:00
parent 103cb0baba
commit 95769cdeff

View File

@ -134,10 +134,9 @@ class MiddleOut(nn.Module):
new_latents.append(new_latent)
new_latents = torch.stack(new_latents)
averaged_latent = torch.mean(new_latents, dim=0)
if self.residual:
return my_latent - averaged_latent
return averaged_latent
return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2)
return torch.mean(new_latents, dim=0)
class Predictor(nn.Module):
def __init__(self, region_latent_size, layer_shapes, activations):