Remove buggy normalization
This commit is contained in:
parent
102ddb8c85
commit
404b59c8ba
@ -134,9 +134,10 @@ class MiddleOut(nn.Module):
|
||||
new_latents.append(new_latent)
|
||||
|
||||
new_latents = torch.stack(new_latents)
|
||||
averaged_latent = torch.mean(new_latents, dim=0)
|
||||
if self.residual:
|
||||
return my_latent - torch.sum(new_latents, dim=0) / torch.sum(peer_metrics, dim=-2)
|
||||
return torch.mean(new_latents, dim=0)
|
||||
return my_latent - averaged_latent
|
||||
return averaged_latent
|
||||
|
||||
class Predictor(nn.Module):
|
||||
def __init__(self, region_latent_size, layer_shapes, activations):
|
||||
|
Loading…
Reference in New Issue
Block a user