From ba2666632981f30f727fbfe28a471431aa7ea628 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 24 Oct 2021 18:07:11 +0200 Subject: [PATCH] Made progress, but not working --- train.py | 12 ++++++------ utils.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index fc6c9a1..92ccb0a 100644 --- a/train.py +++ b/train.py @@ -59,16 +59,16 @@ def evaluate(dataloader, module): y_true_batch = y_true_batch.to(device, dtype) # (batch_size,) y_pred_batch, p, halting_step = module(x_batch) - y_halted_batch = y_pred_batch.gather( + y_halted_batch = y_pred_batch.index_select( dim=0, - index=halting_step[None, :] - 1, + index=halting_step - 1, )[ 0 - ] # (batch_size,) + ] # (batch_size,) # Computing single metrics (mean over samples in the batch) accuracy_halted = ( - ((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean() + abs(y_halted_batch - y_true_batch).to(torch.float32).mean() ) metrics_single_["accuracy_halted"].append(accuracy_halted) @@ -78,9 +78,9 @@ def evaluate(dataloader, module): # Computing per step metrics (mean over samples in the batch) accuracy = ( - ((y_pred_batch > 0) == y_true_batch[None, :]) + (abs(y_pred_batch - y_true_batch[None, :])) .to(torch.float32) - .mean(dim=1) + .mean(dim=1).mean(dim=1) ) metrics_per_step_["accuracy"].append(accuracy) diff --git a/utils.py b/utils.py index 27f11da..002f35a 100644 --- a/utils.py +++ b/utils.py @@ -102,17 +102,21 @@ class PetriDishNet(nn.Module): self.allow_halting = allow_halting self.synapses = SparseLinear(n_neurons, n_neurons, - sparsity=0.9, + sparsity=0.8, + #sparsity=0.1, small_world=True, dynamic=True, - deltaT=6000, + #deltaT=6000, + deltaT=250, Tend=150000, + #Tend=5000, alpha=0.1, max_size=1e8) self.ionDucts = ActivationSparsity(alpha=0.1, beta=1.5, - act_sparsity=0.65) + act_sparsity=0.75) + #act_sparsity=0.10) def forward(self, x): """Run forward pass. @@ -164,6 +168,7 @@ class PetriDishNet(nn.Module): else: lambda_n = torch.sigmoid(state[:, 0]) + # Store releavant outputs y_list.append(state[:, (self.n_in+1):(self.n_in+self.n_out+1)]) p_list.append(un_halted_prob * lambda_n)