Better NN inits

Dominik Moritz Roth 2024-05-25 21:39:47 +02:00
parent ec9794d886
commit b9d531a97b


@@ -51,13 +51,15 @@ class MiddleOut(nn.Module):
         new_latents = []
         for p in range(peer_latents.shape[-2]):
             peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
+            import pdb
+            pdb.set_trace()
             combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
             new_latent = self.fc(combined_input)
-            new_latents.append(new_latent)
+            new_latents.append(new_latent * correlation.unsqueeze(1))
         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
-        return averaged_latent
+        return my_latent - averaged_latent
 
 
 class Predictor(nn.Module):
     def __init__(self, output_size, layer_shapes, activations):
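
Read together, the two functional changes in this hunk weight each peer's transformed latent by its correlation and return the combined result as a residual against the module's own latent. Below is a minimal, self-contained sketch of the resulting forward pass; the MiddleOutSketch name, the constructor, the self.fc dimensions, and the tensor shapes are assumptions not taken from the commit, and the committed pdb.set_trace() debug lines are omitted.

# Minimal sketch of the forward logic after this commit, reconstructed from the
# hunk alone. Constructor, latent_size, and self.fc layout are assumptions.
import torch
import torch.nn as nn


class MiddleOutSketch(nn.Module):
    def __init__(self, latent_size: int):
        super().__init__()
        # Assumed: fc maps (my_latent, peer_latent, correlation) back to latent_size.
        self.fc = nn.Linear(2 * latent_size + 1, latent_size)

    def forward(self, my_latent, peer_latents, peer_correlations):
        # my_latent:         (batch, latent_size)
        # peer_latents:      (batch, num_peers, latent_size)
        # peer_correlations: (batch, num_peers)
        new_latents = []
        for p in range(peer_latents.shape[-2]):
            peer_latent = peer_latents[:, p, :]
            correlation = peer_correlations[:, p]
            combined_input = torch.cat(
                (my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
            new_latent = self.fc(combined_input)
            # Change 1: weight each peer's contribution by its correlation.
            new_latents.append(new_latent * correlation.unsqueeze(1))
        averaged_latent = torch.mean(torch.stack(new_latents), dim=0)
        # Change 2: return the residual against the module's own latent.
        return my_latent - averaged_latent


if __name__ == "__main__":
    batch, peers, latent = 4, 3, 8
    model = MiddleOutSketch(latent)
    out = model(torch.randn(batch, latent),
                torch.randn(batch, peers, latent),
                torch.rand(batch, peers))
    print(out.shape)  # torch.Size([4, 8])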