Implemented ResNet-style residual for MiddleOut and configurable peer gradients
parent 6076aaf36c
commit 5a3d491109

main.py | 3
@@ -68,6 +68,7 @@ class SpikeRunner(Slate_Runner):
         self.learning_rate = slate.consume(training_config, 'learning_rate')
         self.eval_freq = slate.consume(training_config, 'eval_freq')
         self.save_path = slate.consume(training_config, 'save_path')
+        self.peer_gradients = slate.consume(training_config, 'peer_gradients')
 
         # Evaluation parameter
         self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False)
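The new key rides along with the existing training options. A hypothetical training_config matching the consume calls in this hunk; only the key names come from the diff, the values and dict form are illustrative:

# Hypothetical config values; key names mirror the slate.consume calls above.
training_config = {
    'learning_rate': 1e-3,
    'eval_freq': 100,
    'save_path': './checkpoints',
    'peer_gradients': False,  # False -> peer latents get detached during training
}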
@@ -143,6 +144,8 @@ class SpikeRunner(Slate_Runner):
 
             my_latent = latents[:, 0, :]
             peer_latents = latents[:, 1:, :]
+            if not self.peer_gradients:
+                peer_latents = peer_latents.detach()
 
             # Pass through MiddleOut
             new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics))
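A minimal sketch (not the repo's code) of what the flag controls: detached peer latents still contribute values to MiddleOut, but autograd treats them as constants, so no gradient flows back through the peers' latents.

import torch

# Toy stand-in for the latents tensor above: (batch, 1 + num_peers, latent).
latents = torch.randn(4, 3, 8, requires_grad=True)
my_latent = latents[:, 0, :]
peer_latents = latents[:, 1:, :].detach()  # what peer_gradients=False triggers

loss = my_latent.sum() + peer_latents.sum()
loss.backward()
print(latents.grad[:, 1:, :].abs().sum())  # tensor(0.): peers got no gradient
print(latents.grad[:, 0, :].abs().sum())   # tensor(32.): own latent still trains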
@@ -80,10 +80,13 @@ class LatentFourierProjector(nn.Module):
         return latent
 
 class MiddleOut(nn.Module):
-    def __init__(self, latent_size, region_latent_size, num_peers):
+    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
         super(MiddleOut, self).__init__()
+        if residual:
+            assert latent_size == region_latent_size
         self.num_peers = num_peers
         self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
+        self.residual = residual
 
     def forward(self, my_latent, peer_latents, peer_metrics):
         new_latents = []
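Why the new assert: with residual=True, forward() will return my_latent - averaged_latent (see the next hunk), an elementwise operation, so the fc output width (region_latent_size) must equal the width of my_latent (latent_size). A toy shape check, with 8 chosen arbitrarily:

import torch

my_latent = torch.randn(4, 8)        # latent_size = 8
averaged_latent = torch.randn(4, 8)  # region_latent_size must also be 8
out = my_latent - averaged_latent    # (4, 8); mismatched widths would raise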
@@ -95,6 +98,8 @@ class MiddleOut(nn.Module):
 
         new_latents = torch.stack(new_latents)
         averaged_latent = torch.mean(new_latents, dim=0)
+        if self.residual:
+            return my_latent - averaged_latent
         return averaged_latent
 
 class Predictor(nn.Module):
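Putting both hunks together, a self-contained sketch of the resulting module. The per-peer body of forward() is not shown in this commit, so the torch.cat line below is an assumption inferred from the fc input width (latent_size * 2 + 1: my latent, one peer latent, one scalar metric); everything else follows the diff:

import torch
import torch.nn as nn

class MiddleOut(nn.Module):
    def __init__(self, latent_size, region_latent_size, num_peers, residual=False):
        super(MiddleOut, self).__init__()
        if residual:
            # Residual output is my_latent - averaged_latent, so both latent
            # spaces must have the same width.
            assert latent_size == region_latent_size
        self.num_peers = num_peers
        self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size)
        self.residual = residual

    def forward(self, my_latent, peer_latents, peer_metrics):
        new_latents = []
        for p in range(self.num_peers):
            # Assumed per-peer step: concatenate my latent, one peer latent,
            # and that peer's scalar metric, then project.
            metric = peer_metrics[p].reshape(1, 1).expand(my_latent.size(0), 1)
            combined = torch.cat([my_latent, peer_latents[:, p, :], metric], dim=-1)
            new_latents.append(self.fc(combined))
        new_latents = torch.stack(new_latents)
        averaged_latent = torch.mean(new_latents, dim=0)
        if self.residual:
            # ResNet-style: the module predicts a correction relative to
            # my_latent instead of a full replacement latent.
            return my_latent - averaged_latent
        return averaged_latent

# Usage with residual=True (requires latent_size == region_latent_size):
m = MiddleOut(latent_size=8, region_latent_size=8, num_peers=2, residual=True)
out = m(torch.randn(4, 8), torch.randn(4, 2, 8), torch.randn(2))
print(out.shape)  # torch.Size([4, 8])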