initial commit

commit 63cd063c70

README.md (new file, +12 lines)
@@ -0,0 +1,12 @@

# Project PetriDish

Quick and dirty PoC for the idea behind Project Neuromorph.

Combines PonderNet and SparseLinear.

PonderNet stolen from https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/pondernet

SparseLinear stolen from https://pypi.org/project/sparselinear/

## Architecture

We build a neural network composed of a set of neurons that are connected by a set of synapses. Neurons that are close to each other (we use a 1D distance metric; Project Neuromorph will have approximate Euclidean distance) have a higher chance of having a synaptic connection.
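To make the wiring rule concrete, here is a minimal sketch of distance-dependent connectivity (illustrative only, not part of this commit: the exponential falloff and the `scale` parameter are assumptions, and the actual wiring in this PoC is delegated to SparseLinear's small-world mode):

```python
import torch

def sample_synapse_mask(n_neurons: int, scale: float = 4.0) -> torch.Tensor:
    """Sample a boolean connectivity mask for neurons laid out on a 1D line."""
    idx = torch.arange(n_neurons)
    dist = (idx[None, :] - idx[:, None]).abs().float()  # pairwise 1D distances
    p_connect = torch.exp(-dist / scale)  # closer neurons => higher probability
    return torch.bernoulli(p_connect).bool()  # (n_neurons, n_neurons)
```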
We train this net like normal, but we also allow the structure of the synaptic connections to change during training. (The number of neurons remains constant; this is also variable in Neuromorph, which will additionally use more advanced algorithms to decide where to spawn new synapses.)

In every firing cycle only a fraction of all neurons (those with the highest outputs) are allowed to fire; all others are inhibited. (In Project Neuromorph this will be less strict: neurons with low firing rates will have higher dropout chances, and we discourage firing through an additional loss.)
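Such a winner-take-most cycle fits in a few lines; the sketch below is illustrative only (this PoC constructs `ActivationSparsity` from the sparselinear package for the same purpose), and the `frac_firing` parameter is a hypothetical stand-in:

```python
import torch

def inhibit(activations: torch.Tensor, frac_firing: float = 0.35) -> torch.Tensor:
    """Keep the `frac_firing` strongest neurons per sample; inhibit the rest."""
    k = max(1, int(frac_firing * activations.shape[-1]))
    indices = activations.topk(k, dim=-1).indices
    mask = torch.zeros_like(activations).scatter(-1, indices, 1.0)
    return activations * mask
```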
Based on the PonderNet architecture, we allow our network to 'think' about a given problem for as long as it wants (well, OK: there is a maximum number of firing cycles to make training possible).

The inputs and outputs of the network have a fixed length. (Project Neuromorph will also allow variable-length outputs, like in an RNN.)
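The pondering loop itself reduces to the following shape (a simplified, illustrative version of `PetriDishNet.forward` in utils.py below; `step_fn` stands in for one pass through the synapses):

```python
import torch

def ponder(step_fn, state, max_steps=20, eps=0.01):
    """Run `step_fn` until the probability of still running is negligible."""
    un_halted = torch.ones(state.shape[0])
    for n in range(1, max_steps + 1):
        state = step_fn(state)
        lambda_n = torch.sigmoid(state[:, 0])  # halting unit
        un_halted = un_halted * (1 - lambda_n)  # P(not halted after n steps)
        if (un_halted < eps).all():
            break
    return state
```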
__pycache__/utils.cpython-39.pyc (new file, BIN)
Binary file not shown.
train.py (new file, +398 lines)

@@ -0,0 +1,398 @@
from argparse import ArgumentParser
import json
import pathlib

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utils import (
    ParityDataset,
    PetriDishNet,
    ReconstructionLoss,
    RegularizationLoss,
)

@torch.no_grad()
def evaluate(dataloader, module):
    """Compute relevant metrics.

    Parameters
    ----------
    dataloader : DataLoader
        Dataloader that yields batches of `x` and `y`.

    module : PetriDishNet
        Our pondering network.

    Returns
    -------
    metrics_single : dict
        Scalar metrics. The keys are names and the values are `torch.Tensor`.
        These metrics are computed as mean values over the entire dataset.

    metrics_per_step : dict
        Per step metrics. The keys are names and the values are `torch.Tensor`
        of shape `(max_steps,)`. These metrics are computed as mean values over
        the entire dataset.
    """
    # Infer device and dtype
    param = next(module.parameters())
    device, dtype = param.device, param.dtype

    metrics_single_ = {
        "accuracy_halted": [],
        "halting_step": [],
    }
    metrics_per_step_ = {
        "accuracy": [],
        "p": [],
    }

    for x_batch, y_true_batch in dataloader:
        x_batch = x_batch.to(device, dtype)  # (batch_size, n_elems)
        y_true_batch = y_true_batch.to(device, dtype)  # (batch_size,)

        y_pred_batch, p, halting_step = module(x_batch)
        y_halted_batch = y_pred_batch.gather(
            dim=0,
            index=halting_step[None, :] - 1,
        )[
            0
        ]  # (batch_size,)
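        # `halting_step` is 1-based, so `- 1` turns it into an index into the
        # steps dimension: for each sample we keep the prediction produced at
        # the step where that sample halted.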

        # Computing single metrics (mean over samples in the batch)
        accuracy_halted = (
            ((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
        )

        metrics_single_["accuracy_halted"].append(accuracy_halted)
        metrics_single_["halting_step"].append(
            halting_step.to(torch.float).mean()
        )

        # Computing per step metrics (mean over samples in the batch)
        accuracy = (
            ((y_pred_batch > 0) == y_true_batch[None, :])
            .to(torch.float32)
            .mean(dim=1)
        )

        metrics_per_step_["accuracy"].append(accuracy)
        metrics_per_step_["p"].append(p.mean(dim=1))

    metrics_single = {
        name: torch.stack(values).mean(dim=0).cpu().numpy()
        for name, values in metrics_single_.items()
    }

    metrics_per_step = {
        name: torch.stack(values).mean(dim=0).cpu().numpy()
        for name, values in metrics_per_step_.items()
    }

    return metrics_single, metrics_per_step

def plot_distributions(target, predicted):
    """Create a barplot.

    Parameters
    ----------
    target, predicted : np.ndarray
        Arrays of shape `(max_steps,)` representing the target and predicted
        probability distributions.

    Returns
    -------
    matplotlib.Figure
    """
    support = list(range(1, len(target) + 1))

    fig, ax = plt.subplots(dpi=140)

    ax.bar(
        support,
        target,
        color="red",
        label=f"Target - Geometric({target[0].item():.2f})",
    )

    ax.bar(
        support,
        predicted,
        color="green",
        width=0.4,
        label="Predicted",
    )

    ax.set_ylim(0, 0.6)
    ax.set_xticks(support)
    ax.legend()
    ax.grid()

    return fig


def plot_accuracy(accuracy):
    """Create a barplot representing accuracy over different halting steps.

    Parameters
    ----------
    accuracy : np.array
        1D array representing accuracy if we were to take the output after
        the corresponding step.

    Returns
    -------
    matplotlib.Figure
    """
    support = list(range(1, len(accuracy) + 1))

    fig, ax = plt.subplots(dpi=140)

    ax.bar(
        support,
        accuracy,
        label="Accuracy over different steps",
    )

    ax.set_ylim(0, 1)
    ax.set_xticks(support)
    ax.legend()
    ax.grid()

    return fig

def main(argv=None):
    """CLI for training."""
    parser = ArgumentParser()

    parser.add_argument(
        "log_folder",
        type=str,
        help="Folder where tensorboard logging is saved",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.01,
        help="Regularization loss coefficient",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        choices={"cpu", "cuda"},
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=10_000,
        help="Evaluation is run every `eval_frequency` steps",
    )
    parser.add_argument(
        "--lambda-p",
        type=float,
        default=0.4,
        help="True probability of success for a geometric distribution",
    )
    parser.add_argument(
        "--n-iter",
        type=int,
        default=1_000_000,
        help="Number of gradient steps",
    )
    parser.add_argument(
        "--n-elems",
        type=int,
        default=64,
        help="Number of elements",
    )
    parser.add_argument(
        "--n-hidden",
        type=int,
        default=64,
        help="Number of hidden elements in the recurrent cell",
    )
    parser.add_argument(
        "--n-nonzero",
        type=int,
        nargs=2,
        default=(None, None),
        help="Lower and upper bound on nonzero elements in the training set",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=20,
        help="Maximum number of pondering steps",
    )

    # Parameters
    args = parser.parse_args(argv)
    print(args)

    device = torch.device(args.device)
    dtype = torch.float32
    n_eval_samples = 1000
    batch_size_eval = 50

    if args.n_nonzero[0] is None and args.n_nonzero[1] is None:
        threshold = int(0.3 * args.n_elems)
        range_nonzero_easy = (1, threshold)
        range_nonzero_hard = (args.n_elems - threshold, args.n_elems)
    else:
        range_nonzero_easy = (1, args.n_nonzero[1])
        range_nonzero_hard = (args.n_nonzero[1] + 1, args.n_elems)
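    # Two extra evaluation regimes: "easy" = few nonzero entries, "hard" =
    # many. With an explicit --n-nonzero bound the hard range lies outside
    # the training distribution, so it probes extrapolation.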

    # Tensorboard
    log_folder = pathlib.Path(args.log_folder)
    writer = SummaryWriter(log_folder)
    writer.add_text("parameters", json.dumps(vars(args)))

    # Prepare data
    dataloader_train = DataLoader(
        ParityDataset(
            n_samples=args.batch_size * args.n_iter,
            n_elems=args.n_elems,
            n_nonzero_min=args.n_nonzero[0],
            n_nonzero_max=args.n_nonzero[1],
        ),
        batch_size=args.batch_size,
    )  # consider specifying `num_workers` for speedups
    eval_dataloaders = {
        "test": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=args.n_nonzero[0],
                n_nonzero_max=args.n_nonzero[1],
            ),
            batch_size=batch_size_eval,
        ),
        f"{range_nonzero_easy[0]}_{range_nonzero_easy[1]}": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=range_nonzero_easy[0],
                n_nonzero_max=range_nonzero_easy[1],
            ),
            batch_size=batch_size_eval,
        ),
        f"{range_nonzero_hard[0]}_{range_nonzero_hard[1]}": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=range_nonzero_hard[0],
                n_nonzero_max=range_nonzero_hard[1],
            ),
            batch_size=batch_size_eval,
        ),
    }

    # Model preparation
    module = PetriDishNet(
        n_in=args.n_elems,
        n_neurons=args.n_hidden,
        n_out=1,
        max_steps=args.max_steps,
    )
    module = module.to(device, dtype)

    # Loss preparation
    loss_rec_inst = ReconstructionLoss(
        nn.BCEWithLogitsLoss(reduction="none")
    ).to(device, dtype)

    loss_reg_inst = RegularizationLoss(
        lambda_p=args.lambda_p,
        max_steps=args.max_steps,
    ).to(device, dtype)

    # Optimizer
    optimizer = torch.optim.Adam(
        module.parameters(),
        lr=0.0003,
    )

    # Training and evaluation loops
    iterator = tqdm(enumerate(dataloader_train), total=args.n_iter)
    for step, (x_batch, y_true_batch) in iterator:
        x_batch = x_batch.to(device, dtype)
        y_true_batch = y_true_batch.to(device, dtype)

        y_pred_batch, p, halting_step = module(x_batch)

        loss_rec = loss_rec_inst(
            p,
            y_pred_batch,
            y_true_batch,
        )

        loss_reg = loss_reg_inst(
            p,
        )

        loss_overall = loss_rec + args.beta * loss_reg
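        # PonderNet objective: L = sum_n p_n * L(y_n, y_true)
        #                          + beta * KL(p || Geometric(lambda_p)),
        # i.e. the expected reconstruction loss over halting steps plus a
        # prior pulling the halting distribution towards a geometric one.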

        optimizer.zero_grad()
        loss_overall.backward()
        torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
        optimizer.step()

        # Logging
        writer.add_scalar("loss_rec", loss_rec, step)
        writer.add_scalar("loss_reg", loss_reg, step)
        writer.add_scalar("loss_overall", loss_overall, step)

        # Evaluation
        if step % args.eval_frequency == 0:
            module.eval()

            for dataloader_name, dataloader in eval_dataloaders.items():
                metrics_single, metrics_per_step = evaluate(
                    dataloader,
                    module,
                )
                fig_dist = plot_distributions(
                    loss_reg_inst.p_g.cpu().numpy(),
                    metrics_per_step["p"],
                )
                writer.add_figure(
                    f"distributions/{dataloader_name}", fig_dist, step
                )

                fig_acc = plot_accuracy(metrics_per_step["accuracy"])
                writer.add_figure(
                    f"accuracy_per_step/{dataloader_name}", fig_acc, step
                )

                for metric_name, metric_value in metrics_single.items():
                    writer.add_scalar(
                        f"{metric_name}/{dataloader_name}",
                        metric_value,
                        step,
                    )

            torch.save(module, log_folder / "checkpoint.pth")

            module.train()


if __name__ == "__main__":
    main()
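For reference, a typical invocation of this script with the arguments defined above might look like `python train.py logs/petridish --device cuda --n-elems 64 --max-steps 20` (the log folder name is arbitrary).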
utils.py (new file, +289 lines)

@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from sparselinear import SparseLinear, ActivationSparsity

class ParityDataset(Dataset):
    """Parity of vectors - binary classification dataset.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.

    n_elems : int
        Size of the vectors.

    n_nonzero_min, n_nonzero_max : int or None
        Minimum (inclusive) and maximum (inclusive) number of nonzero
        elements in the feature vector. If not specified then `(1, n_elems)`.
    """

    def __init__(
        self,
        n_samples,
        n_elems,
        n_nonzero_min=None,
        n_nonzero_max=None,
    ):
        self.n_samples = n_samples
        self.n_elems = n_elems

        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
        self.n_nonzero_max = (
            n_elems if n_nonzero_max is None else n_nonzero_max
        )

        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems

    def __len__(self):
        """Get the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a feature vector and its parity (target).

        Note that the generating process is random.
        """
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(
            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
        ).item()
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]

        y = (x == 1.0).sum() % 2
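        # Example: x = [0., 1., -1., 1.] has two entries equal to 1.0,
        # so y = 2 % 2 = 0; one more 1.0 would flip the parity to 1.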

        return x, y


class PetriDishNet(nn.Module):
    """Network based on the PetriDish-Architecture.

    Is capable of performing an uncontrolled intelligence-explosion.
    May or may not lead to the technological singularity when executed.
    Any similarity to any pop-culture AI, living or dead, or utopic or
    dystopic events, is purely coincidental.

    Parameters
    ----------
    n_in : int
        Number of features in the input-vector.

    n_out : int
        Number of features in the output-vector.

    n_neurons : int
        Number of neurons.

    max_steps : int
        Maximum number of steps the network can "ponder" for.

    allow_halting : bool
        If True, then the forward pass is allowed to halt before
        reaching the maximum steps.

    Attributes
    ----------
    synapses : SparseLinear
        Magic stuff
    """

    def __init__(
        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
    ):
        super().__init__()

        assert n_neurons >= n_in + n_out + 1, (
            f"Insufficient number of neurons (min: {n_in + n_out + 1})"
        )

        self.n_in = n_in
        self.n_out = n_out
        self.n_neurons = n_neurons
        self.max_steps = max_steps
        self.allow_halting = allow_halting

        self.synapses = SparseLinear(n_neurons, n_neurons,
                                     sparsity=0.9,
                                     small_world=True,
                                     dynamic=True,
                                     deltaT=6000,
                                     Tend=150000,
                                     alpha=0.1,
                                     max_size=1e8)
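        # sparsity=0.9 keeps only ~10% of the possible synapses;
        # small_world=True biases them towards nearby neurons (the distance
        # rule from the README) and dynamic=True lets connections be pruned
        # and regrown during training (controlled by deltaT, Tend and alpha).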

        self.ionDucts = ActivationSparsity(n_neurons,
                                           alpha=0.1,
                                           beta=1.5,
                                           act_sparsity=0.65)
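        # k-winners style inhibition ("only a fraction of all neurons may
        # fire", see README): act_sparsity=0.65 keeps roughly the top 35%
        # of activations. Note that this module is not yet applied inside
        # `forward` in this PoC.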

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input features of shape `(batch_size, n_elems)`.

        Returns
        -------
        y : torch.Tensor
            Tensor of shape `(max_steps, batch_size, n_out)` representing
            the predictions for each step and each sample. In case
            `allow_halting=True` then the shape is
            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.

        p : torch.Tensor
            Tensor of shape `(max_steps, batch_size)` representing
            the halting probabilities. Sums over rows (fixing a sample)
            are 1. In case `allow_halting=True` then the shape is
            `(steps, batch_size)` where `1 <= steps <= max_steps`.

        halting_step : torch.Tensor
            An integer for each sample in the batch that corresponds to
            the step when it was halted. The shape is `(batch_size,)`. The
            minimal value is 1 because we always run at least one step.
        """
        batch_size, _ = x.shape
        device = x.device

        state = torch.nn.functional.pad(
            input=x,
            pad=(0, self.n_neurons - self.n_in),
            mode="constant",
            value=0,
        )

        un_halted_prob = x.new_ones(batch_size)

        y_list = []
        p_list = []

        halting_step = torch.zeros(
            batch_size,
            dtype=torch.long,
            device=device,
        )

        for n in range(1, self.max_steps + 1):
            state = self.synapses(state)
            if n == self.max_steps:
                lambda_n = x.new_ones(batch_size)  # (batch_size,)
            else:
                lambda_n = torch.sigmoid(state[:, 0])
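            # Layout of `state`: neuron 0 doubles as the halting unit (its
            # post-synapse activation becomes lambda_n), neurons
            # n_in+1 .. n_in+n_out are read out as the prediction below,
            # and the remaining neurons act as hidden capacity.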

            # Store relevant outputs
            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
            p_list.append(un_halted_prob * lambda_n)

            halting_step = torch.maximum(
                n
                * (halting_step == 0)
                * torch.bernoulli(lambda_n).to(torch.long),
                halting_step,
            )

            # Prepare for next iteration
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Potentially stop if all samples halted
            # if self.allow_halting and (halting_step > 0).sum() == batch_size:
            if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
                break

        y = torch.stack(y_list)
        p = torch.stack(p_list)

        return y, p, halting_step

class ReconstructionLoss(nn.Module):
    """Weighted average of per step losses.

    Parameters
    ----------
    loss_func : callable
        Loss function that accepts `y_pred` and `y_true` as arguments. Both
        of these tensors have shape `(batch_size, n_out)`. It outputs a loss
        for each sample in the batch.
    """

    def __init__(self, loss_func):
        super().__init__()

        self.loss_func = loss_func

    def forward(self, p, y_pred, y_true):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(max_steps, batch_size)`.

        y_pred : torch.Tensor
            Predicted outputs of shape `(max_steps, batch_size, n_out)`.

        y_true : torch.Tensor
            True targets of shape `(batch_size, n_out)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the reconstruction loss. It is nothing else
            than a weighted sum of per step losses.
        """
        max_steps, _ = p.shape
        total_loss = p.new_tensor(0.0)

        for n in range(max_steps):
            loss_per_sample = p[n] * self.loss_func(
                y_pred[n], y_true
            )  # (batch_size,)
            total_loss = total_loss + loss_per_sample.mean()  # (1,)
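            # Accumulates L_rec = sum_n mean(p_n * loss_n): each step's loss
            # is weighted by the probability of halting at that step.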

        return total_loss


class RegularizationLoss(nn.Module):
    """Enforce the halting distribution to resemble the geometric distribution.

    Parameters
    ----------
    lambda_p : float
        The single parameter uniquely determining the geometric distribution.
        Note that the expected value of this distribution is going to be
        `1 / lambda_p`.

    max_steps : int
        Maximum number of pondering steps.
    """

    def __init__(self, lambda_p, max_steps=20):
        super().__init__()

        p_g = torch.zeros((max_steps,))
        not_halted = 1.0

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)
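        # p_g[k] = (1 - lambda_p)^k * lambda_p: a geometric distribution
        # truncated at max_steps, e.g. lambda_p=0.4 gives
        # [0.4, 0.24, 0.144, ...].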

        self.register_buffer("p_g", p_g)
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, p):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(steps, batch_size)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the regularization loss.
        """
        steps, batch_size = p.shape

        p = p.transpose(0, 1)  # (batch_size, max_steps)

        p_g_batch = self.p_g[None, :steps].expand_as(
            p
        )  # (batch_size, max_steps)

        return self.kl_div(p.log(), p_g_batch)