diff --git a/models.py b/models.py index caadb72..a84faa9 100644 --- a/models.py +++ b/models.py @@ -51,8 +51,6 @@ class MiddleOut(nn.Module): new_latents = [] for p in range(peer_latents.shape[-2]): peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p] - import pdb - pdb.set_trace() combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1) new_latent = self.fc(combined_input) new_latents.append(new_latent * correlation.unsqueeze(1))