From 2ce2e8c3846aae8344aee1370c47fee67d022039 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 26 May 2024 00:28:18 +0200 Subject: [PATCH] bug fixes --- main.py | 21 +++++++++++++++------ models.py | 9 +++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 97b96f9..747382b 100644 --- a/main.py +++ b/main.py @@ -13,6 +13,7 @@ from pycallgraph2.output import GraphvizOutput from slate import Slate, Slate_Runner device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +#device = 'cpu' class SpikeRunner(Slate_Runner): def setup(self, name): @@ -101,6 +102,8 @@ class SpikeRunner(Slate_Runner): for epoch in range(self.epochs): total_loss = 0 + errs = [] + rels = [] for batch_num in range(self.num_batches): # Create indices for training data and shuffle them @@ -115,7 +118,7 @@ class SpikeRunner(Slate_Runner): lead_data = self.train_data[idx][:min_length] # Slide a window over the data with overlap - stride = max(1, self.input_size // 8) # Ensuring stride is at least 1 + stride = max(1, self.input_size // 3) # Ensuring stride is at least 1 for i in range(0, len(lead_data) - self.input_size-1, stride): lead_segment = lead_data[i:i + self.input_size] inputs = torch.tensor(lead_segment, dtype=torch.float32).to(device) @@ -123,9 +126,9 @@ class SpikeRunner(Slate_Runner): # Collect the segments for the current lead and its peers peer_segments = [] for peer_idx in self.sorted_peer_indices[idx]: - peer_segment = self.train_data[peer_idx][i:i + self.input_size][:min_length] + peer_segment = self.train_data[peer_idx][i:i + self.input_size] peer_segments.append(torch.tensor(peer_segment, dtype=torch.float32).to(device)) - peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device) # Shape: (num_peers) + peer_correlation = torch.tensor([self.correlation_matrix[idx, peer_idx] for peer_idx in self.sorted_peer_indices[idx]], dtype=torch.float32).to(device) peer_correlations.append(peer_correlation) # Stack the segments to form the batch @@ -145,13 +148,20 @@ class SpikeRunner(Slate_Runner): prediction = self.predictor(new_latent) # Calculate loss and backpropagate - loss = self.criterion(prediction, torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device)) + tar = torch.tensor(targets, dtype=torch.float32).unsqueeze(-1).to(device) + loss = self.criterion(prediction, tar) + err = np.sum(np.abs(prediction.cpu().detach().numpy() - tar.cpu().detach().numpy())) + rel = err / np.sum(tar.cpu().detach().numpy()) total_loss += loss.item() + errs.append(err.item()) + rels.append(rel.item()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() - wandb.log({"epoch": epoch, "loss": total_loss}, step=epoch) + tot_err = sum(errs)/len(errs) + tot_rel = sum(rels)/len(rels) + wandb.log({"epoch": epoch, "loss": total_loss, "err": tot_err, "rel": tot_rel}, step=epoch) print(f'Epoch {epoch + 1}/{self.epochs}, Loss: {total_loss}') if self.eval_freq != -1 and (epoch + 1) % self.eval_freq == 0: @@ -281,7 +291,6 @@ class SpikeRunner(Slate_Runner): print('Evaluation done for this epoch.') return avg_loss - def save_models(self, epoch): return print('Saving models...') diff --git a/models.py b/models.py index a84faa9..f973584 100644 --- a/models.py +++ b/models.py @@ -37,8 +37,9 @@ class LatentRNNProjector(nn.Module): self.latent_size = latent_size def forward(self, x): - out, _ = self.rnn(x) - latent = self.fc(out) + batch_1, batch_2, timesteps = x.size() + out, _ = self.rnn(x.view(batch_1 * batch_2, timesteps)) + latent = self.fc(out).view(batch_1, batch_2, self.latent_size) return latent class MiddleOut(nn.Module): @@ -57,7 +58,7 @@ class MiddleOut(nn.Module): new_latents = torch.stack(new_latents) averaged_latent = torch.mean(new_latents, dim=0) - return my_latent - averaged_latent + return averaged_latent class Predictor(nn.Module): def __init__(self, output_size, layer_shapes, activations): @@ -73,4 +74,4 @@ class Predictor(nn.Module): self.fc = nn.Sequential(*layers) def forward(self, latent): - return self.fc(latent) + return self.fc(latent) \ No newline at end of file