Better ResNets

This commit is contained in:
Dominik Moritz Roth 2024-05-26 23:56:02 +02:00
parent 90af40b317
commit 16ba578737

View File

@ -99,7 +99,9 @@ class MiddleOut(nn.Module):
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1)
new_latent = self.fc(combined_input)
new_latents.append(new_latent * metric.unsqueeze(1))
if self.residual:
new_latent = new_latent * metric.unsqueeze(1)
new_latents.append(new_latent)
new_latents = torch.stack(new_latents)
averaged_latent = torch.mean(new_latents, dim=0)