From 5a3d49110932e8f1d51266272e61db4eaeb120fc Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 26 May 2024 15:58:45 +0200 Subject: [PATCH] Implemented ResNet style for MiddleOut and configurable peer gradients --- main.py | 3 +++ models.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 04e0202..f01f501 100644 --- a/main.py +++ b/main.py @@ -68,6 +68,7 @@ class SpikeRunner(Slate_Runner): self.learning_rate = slate.consume(training_config, 'learning_rate') self.eval_freq = slate.consume(training_config, 'eval_freq') self.save_path = slate.consume(training_config, 'save_path') + self.peer_gradients = slate.consume(training_config, 'peer_gradients') # Evaluation parameter self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False) @@ -143,6 +144,8 @@ class SpikeRunner(Slate_Runner): my_latent = latents[:, 0, :] peer_latents = latents[:, 1:, :] + if not self.peer_gradients: + peer_latents = peer_latents.detach() # Pass through MiddleOut new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) diff --git a/models.py b/models.py index 7f5e98e..9569a92 100644 --- a/models.py +++ b/models.py @@ -80,10 +80,13 @@ class LatentFourierProjector(nn.Module): return latent class MiddleOut(nn.Module): - def __init__(self, latent_size, region_latent_size, num_peers): + def __init__(self, latent_size, region_latent_size, num_peers, residual=False): super(MiddleOut, self).__init__() + if residual: + assert latent_size == region_latent_size self.num_peers = num_peers self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size) + self.residual = residual def forward(self, my_latent, peer_latents, peer_metrics): new_latents = [] @@ -95,6 +98,8 @@ class MiddleOut(nn.Module): new_latents = torch.stack(new_latents) averaged_latent = torch.mean(new_latents, dim=0) + if self.residual: + return my_latent - averaged_latent return averaged_latent class Predictor(nn.Module):