Better ResNets
This commit is contained in:
parent
90af40b317
commit
16ba578737
@ -99,7 +99,9 @@ class MiddleOut(nn.Module):
|
||||
peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]
|
||||
combined_input = torch.cat((my_latent, peer_latent, metric.unsqueeze(1)), dim=-1)
|
||||
new_latent = self.fc(combined_input)
|
||||
new_latents.append(new_latent * metric.unsqueeze(1))
|
||||
if self.residual:
|
||||
new_latent = new_latent * metric.unsqueeze(1)
|
||||
new_latents.append(new_latent)
|
||||
|
||||
new_latents = torch.stack(new_latents)
|
||||
averaged_latent = torch.mean(new_latents, dim=0)
|
||||
|
Loading…
Reference in New Issue
Block a user