diff --git a/main.py b/main.py index f01f501..6b6e274 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,6 @@ from pycallgraph2 import PyCallGraph from pycallgraph2.output import GraphvizOutput from slate import Slate, Slate_Runner -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -#device = 'cpu' class SpikeRunner(Slate_Runner): def setup(self, name): @@ -49,6 +47,10 @@ class SpikeRunner(Slate_Runner): latent_size = slate.consume(config, 'latent_projector.latent_size') input_size = slate.consume(config, 'latent_projector.input_size') region_latent_size = slate.consume(config, 'middle_out.region_latent_size') + device = slate.consume(training_config, 'device') + if device == 'auto': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device if latent_projector_type == 'fc': self.projector = LatentFCProjector(latent_size=latent_size, input_size=input_size, **slate.consume(config, 'latent_projector', expand=True)).to(device) @@ -68,7 +70,8 @@ class SpikeRunner(Slate_Runner): self.learning_rate = slate.consume(training_config, 'learning_rate') self.eval_freq = slate.consume(training_config, 'eval_freq') self.save_path = slate.consume(training_config, 'save_path') - self.peer_gradients = slate.consume(training_config, 'peer_gradients') + self.peer_gradients_factor = float(slate.consume(training_config, 'peer_gradients_factor', 1.0)) + self.value_scale = slate.consume(training_config, 'value_scale') # Evaluation parameter self.full_compression = slate.consume(config, 'evaluation.full_compression', default=False) @@ -98,8 +101,9 @@ class SpikeRunner(Slate_Runner): self.train_model() def train_model(self): + device = self.device min_length = min([len(seq) for seq in self.train_data]) - + best_test_score = float('inf') for epoch in range(self.epochs): @@ -121,7 +125,8 @@ class SpikeRunner(Slate_Runner): # Slide a window over the data with overlap stride = max(1, self.input_size // 3) # Ensuring stride is at least 1 - for i in range(0, len(lead_data) - self.input_size-1, stride): + offset = np.random.randint(0, stride) + for i in range(offset, len(lead_data) - self.input_size-1-offset, stride): lead_segment = lead_data[i:i + self.input_size] inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device) @@ -140,16 +145,22 @@ class SpikeRunner(Slate_Runner): targets.append(target) # Pass the batch through the projector - latents = self.projector(torch.stack(stacked_segments)) + latents = self.projector(torch.stack(stacked_segments)/self.value_scale) my_latent = latents[:, 0, :] peer_latents = latents[:, 1:, :] - if not self.peer_gradients: + + # Scale gradients during backwards pass as configured + if self.peer_gradients_factor == 1.0: + pass + elif self.peer_gradients_factor == 0.0: peer_latents = peer_latents.detach() + else: + peer_latents.register_hook(lambda grad: grad*self.peer_gradients_factor) # Pass through MiddleOut - new_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) - prediction = self.predictor(new_latent) + region_latent = self.middle_out(my_latent, peer_latents, torch.stack(peer_metrics)) + prediction = self.predictor(region_latent)*self.value_scale # Calculate loss and backpropagate tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) @@ -157,15 +168,15 @@ class SpikeRunner(Slate_Runner): err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) rel = err / np.sum(tar.cpu().detach().numpy()) total_loss += loss.item() - errs.append(err.item()) + errs.append(err/np.prod(tar.size()).item()) rels.append(rel.item()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() tot_err = sum(errs)/len(errs) - tot_rel = sum(rels)/len(rels) - wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch) + approx_ratio = 1/(sum(rels)/len(rels)) + wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "approx_ratio": approx_ratio}, step=epoch) print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: @@ -191,6 +202,8 @@ class SpikeRunner(Slate_Runner): def evaluate_model(self, epoch): print('Evaluating model...') + device = self.device + self.projector.eval() self.middle_out.eval() self.predictor.eval() diff --git a/models.py b/models.py index 9569a92..84918df 100644 --- a/models.py +++ b/models.py @@ -53,30 +53,31 @@ class LatentFourierProjector(nn.Module): super(LatentFourierProjector, self).__init__() self.fourier_transform = FourierTransformLayer() layers = [] + if pass_raw_len is None: pass_raw_len = input_size else: assert pass_raw_len <= input_size + in_features = pass_raw_len + (input_size // 2 + 1) * 2 # (input_size // 2 + 1) real + imaginary parts for i, out_features in enumerate(layer_shapes): layers.append(nn.Linear(in_features, out_features)) if activations[i] != 'None': layers.append(get_activation(activations[i])) in_features = out_features + layers.append(nn.Linear(in_features, latent_size)) self.fc = nn.Sequential(*layers) self.latent_size = latent_size self.pass_raw_len = pass_raw_len def forward(self, x): - # Apply Fourier Transform - x_fft = self.fourier_transform(x) - # Separate real and imaginary parts and combine them + batch_1, batch_2, timesteps = x.size() + x_fft = self.fourier_transform(x.view(batch_1 * batch_2, timesteps)) x_fft_real_imag = torch.cat((x_fft.real, x_fft.imag), dim=-1) - # Combine part of the raw input with Fourier features - combined_input = torch.cat([x[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1) - # Process through fully connected layers + combined_input = torch.cat([x.view(batch_1 * batch_2, timesteps)[:, -self.pass_raw_len:], x_fft_real_imag], dim=-1) latent = self.fc(combined_input) + latent = latent.view(batch_1, batch_2, self.latent_size) return latent class MiddleOut(nn.Module): @@ -84,11 +85,15 @@ class MiddleOut(nn.Module): super(MiddleOut, self).__init__() if residual: assert latent_size == region_latent_size + if num_peers == 0: + assert latent_size == region_latent_size self.num_peers = num_peers self.fc = nn.Linear(latent_size * 2 + 1, region_latent_size) self.residual = residual def forward(self, my_latent, peer_latents, peer_metrics): + if self.num_peers == 0: + return my_latent new_latents = [] for p in range(peer_latents.shape[-2]): peer_latent, metric = peer_latents[:, p, :], peer_metrics[:, p]