diff --git a/train.py b/train.py index 5c41e90..fc6c9a1 100644 --- a/train.py +++ b/train.py @@ -306,7 +306,7 @@ def main(argv=None): } # Model preparation - module = PetriDashNet( + module = PetriDishNet( n_in=args.n_elems, n_neurons=args.n_hidden, n_out=1, diff --git a/utils.py b/utils.py index e4fdf41..27f11da 100644 --- a/utils.py +++ b/utils.py @@ -54,7 +54,7 @@ class ParityDataset(Dataset): y = (x == 1.0).sum() % 2 - return x, y + return x, torch.stack([y]) class PetriDishNet(nn.Module): @@ -110,8 +110,7 @@ class PetriDishNet(nn.Module): alpha=0.1, max_size=1e8) - self.ionDucts = ActivationSparsity(n_neurons, - alpha=0.1, + self.ionDucts = ActivationSparsity(alpha=0.1, beta=1.5, act_sparsity=0.65)