Implemented ResNet-style residual for MiddleOut and configurable peer gradients

Dominik Moritz Roth 2024-05-26 15:58:45 +02:00
parent 6076aaf36c
commit 5a3d491109
2 changed files with 9 additions and 1 deletion


@@ -68,6 +68,7 @@ class SpikeRunner(Slate_Runner):
         self.learning_rate = slate.consume(training_config, 'learning_rate')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
         self.save_path = slate.consume(training_config, 'save_path')
+        self.peer_gradients = slate.consume(training_config, 'peer_gradients')
 
         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
@@ -143,6 +144,8 @@ class SpikeRunner(Slate_Runner):
             my_latent = latents[:, 0, :]
             peer_latents = latents[:, 1:, :]
+            if not self.peer_gradients:
+                peer_latents = peer_latents.detach()
 
             # Pass through MiddleOut
             new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
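In plain terms, the new peer_gradients flag controls whether the training loss may backpropagate into the peers' encoders. Below is a minimal, self-contained sketch (toy shapes and variable names, not the repository's code) of why detach() has that effect:

import torch

# Toy stand-ins for the latents in SpikeRunner; shapes are illustrative.
latents = torch.randn(4, 3, 8, requires_grad=True)  # (batch, 1 + num_peers, latent_size)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]

peer_gradients = False  # mirrors the new training-config flag
if not peer_gradients:
    # detach() returns the same values, but cut out of the autograd graph
    peer_latents = peer_latents.detach()

loss = my_latent.sum() + peer_latents.sum()
loss.backward()

print(latents.grad[:, 0, :].abs().sum() > 0)    # tensor(True): own slice receives gradient
print(latents.grad[:, 1:, :].abs().sum() == 0)  # tensor(True): peer slices stay zero

Detaching only severs the gradient path; the forward values of peer_latents are unchanged, so predictions are identical either way.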


@@ -80,10 +80,13 @@ class LatentFourierProjector(nn.Module):
         return latent
 
 class MiddleOut(nn.Module):
-    def __init__(self, latent_size, region_latent_size, num_peers):
+    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
         super(MiddleOut, self).__init__()
+        if residual:
+            assert latent_size == region_latent_size
         self.num_peers = num_peers
         self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
+        self.residual = residual
 
     def forward(self, my_latent, peer_latents, peer_metrics):
         new_latents = []
@@ -95,6 +98,8 @@ class MiddleOut(nn.Module):
         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
+        if self.residual:
+            return my_latent - averaged_latent
         return averaged_latent
 
 class Predictor(nn.Module):
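Putting both hunks together, here is a hedged reconstruction of the complete module. The diff hides the per-peer loop body, so the concatenation step below is an assumption inferred from the layer's input width of latent_size * 2 + 1; the __init__ lines, the stack/mean, and the residual return are taken from the hunks above.

import torch
import torch.nn as nn

class MiddleOut(nn.Module):
    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
        super(MiddleOut, self).__init__()
        if residual:
            # The residual path returns my_latent - averaged_latent,
            # so input and output sizes must match.
            assert latent_size == region_latent_size
        self.num_peers = num_peers
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
        self.residual = residual

    def forward(self, my_latent, peer_latents, peer_metrics):
        new_latents = []
        # Hypothetical per-peer step (hidden in the diff): combine my_latent,
        # one peer latent, and that peer's scalar metric, then project.
        # in_features = latent_size * 2 + 1 suggests exactly this layout.
        for i in range(peer_latents.shape[1]):
            combined = torch.cat(
                [my_latent, peer_latents[:, i, :], peer_metrics[i].unsqueeze(-1)],
                dim=-1,
            )
            new_latents.append(self.fc(combined))

        new_latents = torch.stack(new_latents)
        averaged_latent = torch.mean(new_latents, dim=0)
        if self.residual:
            # ResNet-style: treat my_latent as a baseline and learn a correction.
            return my_latent - averaged_latent
        return averaged_latent

# Illustrative usage (shapes assumed, not from the repository):
model = MiddleOut(latent_size=32, region_latent_size=32, num_peers=2, residual=True)
out = model(torch.randn(4, 32), torch.randn(4, 2, 32), torch.randn(2, 4))  # -> (4, 32)

Note that the residual path subtracts the averaged output rather than adding it: the network learns an error term to remove from my_latent, which is also why the new assert pins region_latent_size to latent_size.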