diff --git a/models.py b/models.py index 84918df..8d999f5 100644 --- a/models.py +++ b/models.py @@ -99,7 +99,9 @@ class MiddleOut(nn.Module): peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p] combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1) new_latent = self.fc(combined_input) - new_latents.append(new_latent * metric.unsqueeze(1)) + if self.residual: + new_latent = new_latent * metric.unsqueeze(1) + new_latents.append(new_latent) new_latents = torch.stack(new_latents) averaged_latent = torch.mean(new_latents, dim=0)