From 16ba5787375c8509973329f3e910eed12c84a5c8 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 26 May 2024 23:56:02 +0200 Subject: [PATCH] Better ResNets --- models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index 84918df..8d999f5 100644 --- a/models.py +++ b/models.py @@ -99,7 +99,9 @@ class MiddleOut(nn.Module): peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p] combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1) new_latent = self.fc(combined_input) - new_latents.append(new_latent * metric.unsqueeze(1)) + if self.residual: + new_latent = new_latent * metric.unsqueeze(1) + new_latents.append(new_latent) new_latents = torch.stack(new_latents) averaged_latent = torch.mean(new_latents, dim=0)