diff --git a/models.py b/models.py index cf9971d..caadb72 100644 --- a/models.py +++ b/models.py @@ -51,13 +51,15 @@ class MiddleOut(nn.Module): new_latents = [] for p in range(peer_latents.shape[-2]): peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p] + import pdb + pdb.set_trace() combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1) new_latent = self.fc(combined_input) - new_latents.append(new_latent) + new_latents.append(new_latent * correlation.unsqueeze(1)) new_latents = torch.stack(new_latents) averaged_latent = torch.mean(new_latents, dim=0) - return averaged_latent + return my_latent - averaged_latent class Predictor(nn.Module): def __init__(self, output_size, layer_shapes, activations):