Made progress, but not working
This commit is contained in:
parent
2c9de61fee
commit
ba26666329
12
train.py
12
train.py
@ -59,16 +59,16 @@ def evaluate(dataloader, module):
|
||||
y_true_batch = y_true_batch.to(device, dtype) # (batch_size,)
|
||||
|
||||
y_pred_batch, p, halting_step = module(x_batch)
|
||||
y_halted_batch = y_pred_batch.gather(
|
||||
y_halted_batch = y_pred_batch.index_select(
|
||||
dim=0,
|
||||
index=halting_step[None, :] - 1,
|
||||
index=halting_step - 1,
|
||||
)[
|
||||
0
|
||||
] # (batch_size,)
|
||||
] # (batch_size,)
|
||||
|
||||
# Computing single metrics (mean over samples in the batch)
|
||||
accuracy_halted = (
|
||||
((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
|
||||
abs(y_halted_batch - y_true_batch).to(torch.float32).mean()
|
||||
)
|
||||
|
||||
metrics_single_["accuracy_halted"].append(accuracy_halted)
|
||||
@ -78,9 +78,9 @@ def evaluate(dataloader, module):
|
||||
|
||||
# Computing per step metrics (mean over samples in the batch)
|
||||
accuracy = (
|
||||
((y_pred_batch > 0) == y_true_batch[None, :])
|
||||
(abs(y_pred_batch - y_true_batch[None, :]))
|
||||
.to(torch.float32)
|
||||
.mean(dim=1)
|
||||
.mean(dim=1).mean(dim=1)
|
||||
)
|
||||
|
||||
metrics_per_step_["accuracy"].append(accuracy)
|
||||
|
11
utils.py
11
utils.py
@ -102,17 +102,21 @@ class PetriDishNet(nn.Module):
|
||||
self.allow_halting = allow_halting
|
||||
|
||||
self.synapses = SparseLinear(n_neurons, n_neurons,
|
||||
sparsity=0.9,
|
||||
sparsity=0.8,
|
||||
#sparsity=0.1,
|
||||
small_world=True,
|
||||
dynamic=True,
|
||||
deltaT=6000,
|
||||
#deltaT=6000,
|
||||
deltaT=250,
|
||||
Tend=150000,
|
||||
#Tend=5000,
|
||||
alpha=0.1,
|
||||
max_size=1e8)
|
||||
|
||||
self.ionDucts = ActivationSparsity(alpha=0.1,
|
||||
beta=1.5,
|
||||
act_sparsity=0.65)
|
||||
act_sparsity=0.75)
|
||||
#act_sparsity=0.10)
|
||||
|
||||
def forward(self, x):
|
||||
"""Run forward pass.
|
||||
@ -164,6 +168,7 @@ class PetriDishNet(nn.Module):
|
||||
else:
|
||||
lambda_n = torch.sigmoid(state[:, 0])
|
||||
|
||||
|
||||
# Store releavant outputs
|
||||
y_list.append(state[:, (self.n_in+1):(self.n_in+self.n_out+1)])
|
||||
p_list.append(un_halted_prob * lambda_n)
|
||||
|
Loading…
Reference in New Issue
Block a user