Better NN inits
parent ec9794d886
commit b9d531a97b

@@ -51,13 +51,15 @@ class MiddleOut(nn.Module):
         new_latents = []
         for p in range(peer_latents.shape[-2]):
             peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
+            import pdb
+            pdb.set_trace()
             combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
             new_latent = self.fc(combined_input)
-            new_latents.append(new_latent)
+            new_latents.append(new_latent * correlation.unsqueeze(1))

         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
-        return averaged_latent
+        return my_latent - averaged_latent

 class Predictor(nn.Module):
     def __init__(self, output_size, layer_shapes, activations):
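
For context, here is a self-contained sketch of the forward pass as it looks after this commit. The constructor is not part of the hunk, so the latent size, the shape of self.fc, the class name MiddleOutSketch, and the dummy tensor shapes below are assumptions made for illustration only. The pdb breakpoint added in the commit is omitted, since it only pauses execution for debugging.

    # Minimal sketch of the reworked forward pass; layer sizes and shapes are assumed.
    import torch
    import torch.nn as nn

    class MiddleOutSketch(nn.Module):
        def __init__(self, latent_size=16):
            super().__init__()
            # fc maps (my_latent, peer_latent, correlation) to an adjustment latent
            self.fc = nn.Linear(2 * latent_size + 1, latent_size)

        def forward(self, my_latent, peer_latents, peer_correlations):
            # my_latent:         (batch, latent)
            # peer_latents:      (batch, n_peers, latent)
            # peer_correlations: (batch, n_peers)
            new_latents = []
            for p in range(peer_latents.shape[-2]):
                peer_latent, correlation = peer_latents[:, p, :], peer_correlations[:, p]
                combined_input = torch.cat((my_latent, peer_latent, correlation.unsqueeze(1)), dim=-1)
                new_latent = self.fc(combined_input)
                # New in this commit: weight each peer's contribution by its correlation.
                new_latents.append(new_latent * correlation.unsqueeze(1))

            averaged_latent = torch.mean(torch.stack(new_latents), dim=0)
            # New in this commit: return the residual instead of the raw average.
            return my_latent - averaged_latent

    # Quick shape check with dummy data.
    model = MiddleOutSketch(latent_size=16)
    my_latent = torch.randn(4, 16)
    peer_latents = torch.randn(4, 3, 16)
    peer_correlations = torch.rand(4, 3)
    print(model(my_latent, peer_latents, peer_correlations).shape)  # torch.Size([4, 16])

The two behavioural changes are the ones marked in the comments: each peer's transformed latent is scaled by its correlation before averaging, and the module now returns the residual my_latent - averaged_latent rather than the average itself.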