Implemented ResNet-style residual for MiddleOut and configurable peer gradients
parent 6076aaf36c
commit 5a3d491109
main.py
@@ -68,6 +68,7 @@ class SpikeRunner(Slate_Runner):
         self.learning_rate = slate.consume(training_config, 'learning_rate')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
         self.save_path = slate.consume(training_config, 'save_path')
+        self.peer_gradients = slate.consume(training_config, 'peer_gradients')

         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
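The new peer_gradients flag is consumed from the training config like the neighbouring hyperparameters. A minimal sketch of the corresponding config entry, assuming training_config behaves like a plain mapping (the actual config format and values are not part of this diff):

# Hypothetical training-config excerpt; only 'peer_gradients' is new in this commit.
training_config = {
    'learning_rate': 1e-3,
    'eval_freq': 100,
    'save_path': 'checkpoints/',
    'peer_gradients': False,  # False: peer latents get detached, so no gradients flow to peers
}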
@@ -143,6 +144,8 @@ class SpikeRunner(Slate_Runner):

         my_latent = latents[:, 0, :]
         peer_latents = latents[:, 1:, :]
+        if not self.peer_gradients:
+            peer_latents = peer_latents.detach()

         # Pass through MiddleOut
         new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
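Setting peer_gradients to false cuts the peers out of the backward pass while the node's own latent stays trainable. A self-contained sketch of the detach() semantics (shapes and names are illustrative, not from the codebase):

import torch

latents = torch.randn(4, 3, 8, requires_grad=True)  # (batch, 1 + num_peers, latent_size)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :]

peer_gradients = False
if not peer_gradients:
    # detach() returns a tensor cut off from the autograd graph, so any loss
    # built from peer_latents contributes no gradient back to `latents`.
    peer_latents = peer_latents.detach()

loss = my_latent.sum() + peer_latents.sum()
loss.backward()
print(latents.grad[:, 0, :].abs().sum())   # non-zero: own latent still trains
print(latents.grad[:, 1:, :].abs().sum())  # zero: peer slices receive no gradient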
@@ -80,10 +80,13 @@ class LatentFourierProjector(nn.Module):
         return latent

 class MiddleOut(nn.Module):
-    def __init__(self, latent_size, region_latent_size, num_peers):
+    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
         super(MiddleOut, self).__init__()
+        if residual:
+            assert latent_size == region_latent_size
         self.num_peers = num_peers
         self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
+        self.residual = residual

     def forward(self, my_latent, peer_latents, peer_metrics):
         new_latents = []
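The new assert guards the residual path introduced in the next hunk: my_latent carries latent_size features while averaged_latent comes out of self.fc with region_latent_size features, so the elementwise subtraction is only defined when the two widths match. A quick illustration, assuming the class as patched here:

# Widths agree, so the residual subtraction is well defined:
m = MiddleOut(latent_size=16, region_latent_size=16, num_peers=4, residual=True)
# Would trip the assert, since the fc output could not be subtracted from my_latent:
# m = MiddleOut(latent_size=16, region_latent_size=8, num_peers=4, residual=True)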
@@ -95,6 +98,8 @@ class MiddleOut(nn.Module):

         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
+        if self.residual:
+            return my_latent - averaged_latent
         return averaged_latent

 class Predictor(nn.Module):
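For reference, the patched MiddleOut pieced together from the two hunks above. The per-peer loop body falls between the hunks and is not shown, so the version below reconstructs it from the fc input width (latent_size * 2 + 1: own latent, peer latent, one metric scalar); treat that loop body, and the metric shapes, as assumptions rather than the repository's actual code:

import torch
import torch.nn as nn

class MiddleOut(nn.Module):
    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
        super(MiddleOut, self).__init__()
        if residual:
            # The residual branch subtracts the fc output from my_latent,
            # so the two feature widths must agree.
            assert latent_size == region_latent_size
        self.num_peers = num_peers
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
        self.residual = residual

    def forward(self, my_latent, peer_latents, peer_metrics):
        new_latents = []
        # Assumed loop body: concatenate own latent, one peer latent and that
        # peer's scalar metric, then project with self.fc.
        for i in range(peer_latents.size(1)):
            peer_latent = peer_latents[:, i, :]
            metric = peer_metrics[i].reshape(1, 1).expand(my_latent.size(0), 1)
            combined = torch.cat([my_latent, peer_latent, metric], dim=-1)
            new_latents.append(self.fc(combined))
        new_latents = torch.stack(new_latents)
        averaged_latent = torch.mean(new_latents, dim=0)
        if self.residual:
            # "ResNet style" skip path; note the commit subtracts the averaged
            # correction from my_latent rather than adding it.
            return my_latent - averaged_latent
        return averaged_latent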