initial commit
This commit is contained in:
commit
63cd063c70
12
README.md
Normal file
12
README.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Project PetriDish
|
||||
Quick and dirty PoC for the idea behind Project Neuromorph.
|
||||
Combines PonderNet and SparseLinear.
|
||||
PonderNet stolen from https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/pondernet
|
||||
SparseLinear stolen from https://pypi.org/project/sparselinear/
|
||||
|
||||
## Architecture
|
||||
We a neural network comprised of a set of neurons that are connected using a set of synapses. Neurons that are close to each other (es use a 1D-Distance metric (Project Neuromorph will have approximate euclidean distance)) have a higher chance to have a synaptic connection.
|
||||
We train this net like normal; but we also allow the structure of the synapctic connections to change during training. (Number of neurons remains constant; this is also variable in Neuromorph; Neuromorph will also use more advanced algorithms to decide where to spawn new synapses)
|
||||
In every firing-cycle only a fraction of all neurons are allowed to fire (highest output) all others are inhibited. (In Project Neuromorph this will be less strict; low firing-rates will have higher dropout-chances and we discurage firing thought an additional loss)
|
||||
Based on the PonderNet-Architecture we allow our network to 'think' as long as it wants about a given problem (well, ok; there is a maximum amount of firing-cycles to make training possbile)
|
||||
The input's and output's of the network have a fixed length. (Project Neuromorph will also allow variable-length outputs like in a RNN)
|
BIN
__pycache__/utils.cpython-39.pyc
Normal file
BIN
__pycache__/utils.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
398
train.py
Normal file
398
train.py
Normal file
@ -0,0 +1,398 @@
|
||||
from argparse import ArgumentParser
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from tqdm import tqdm
|
||||
|
||||
from utils import (
|
||||
ParityDataset,
|
||||
PetriDishNet,
|
||||
ReconstructionLoss,
|
||||
RegularizationLoss,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(dataloader, module):
|
||||
"""Compute relevant metrics.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataloader : DataLoader
|
||||
Dataloader that yields batches of `x` and `y`.
|
||||
|
||||
module : PonderNet
|
||||
Our pondering network.
|
||||
|
||||
Returns
|
||||
-------
|
||||
metrics_single : dict
|
||||
Scalar metrics. The keys are names and the values are `torch.Tensor`.
|
||||
These metrics are computed as mean values over the entire dataset.
|
||||
|
||||
metrics_per_step : dict
|
||||
Per step metrics. The keys are names and the values are `torch.Tensor`
|
||||
of shape `(max_steps,)`. These metrics are computed as mean values over
|
||||
the entire dataset.
|
||||
|
||||
"""
|
||||
# Imply device and dtype
|
||||
param = next(module.parameters())
|
||||
device, dtype = param.device, param.dtype
|
||||
|
||||
metrics_single_ = {
|
||||
"accuracy_halted": [],
|
||||
"halting_step": [],
|
||||
}
|
||||
metrics_per_step_ = {
|
||||
"accuracy": [],
|
||||
"p": [],
|
||||
}
|
||||
|
||||
for x_batch, y_true_batch in dataloader:
|
||||
x_batch = x_batch.to(device, dtype) # (batch_size, n_elems)
|
||||
y_true_batch = y_true_batch.to(device, dtype) # (batch_size,)
|
||||
|
||||
y_pred_batch, p, halting_step = module(x_batch)
|
||||
y_halted_batch = y_pred_batch.gather(
|
||||
dim=0,
|
||||
index=halting_step[None, :] - 1,
|
||||
)[
|
||||
0
|
||||
] # (batch_size,)
|
||||
|
||||
# Computing single metrics (mean over samples in the batch)
|
||||
accuracy_halted = (
|
||||
((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
|
||||
)
|
||||
|
||||
metrics_single_["accuracy_halted"].append(accuracy_halted)
|
||||
metrics_single_["halting_step"].append(
|
||||
halting_step.to(torch.float).mean()
|
||||
)
|
||||
|
||||
# Computing per step metrics (mean over samples in the batch)
|
||||
accuracy = (
|
||||
((y_pred_batch > 0) == y_true_batch[None, :])
|
||||
.to(torch.float32)
|
||||
.mean(dim=1)
|
||||
)
|
||||
|
||||
metrics_per_step_["accuracy"].append(accuracy)
|
||||
metrics_per_step_["p"].append(p.mean(dim=1))
|
||||
|
||||
metrics_single = {
|
||||
name: torch.stack(values).mean(dim=0).cpu().numpy()
|
||||
for name, values in metrics_single_.items()
|
||||
}
|
||||
|
||||
metrics_per_step = {
|
||||
name: torch.stack(values).mean(dim=0).cpu().numpy()
|
||||
for name, values in metrics_per_step_.items()
|
||||
}
|
||||
|
||||
return metrics_single, metrics_per_step
|
||||
|
||||
|
||||
def plot_distributions(target, predicted):
|
||||
"""Create a barplot.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target, predicted : np.ndarray
|
||||
Arrays of shape `(max_steps,)` representing the target and predicted
|
||||
probability distributions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.Figure
|
||||
"""
|
||||
support = list(range(1, len(target) + 1))
|
||||
|
||||
fig, ax = plt.subplots(dpi=140)
|
||||
|
||||
ax.bar(
|
||||
support,
|
||||
target,
|
||||
color="red",
|
||||
label=f"Target - Geometric({target[0].item():.2f})",
|
||||
)
|
||||
|
||||
ax.bar(
|
||||
support,
|
||||
predicted,
|
||||
color="green",
|
||||
width=0.4,
|
||||
label="Predicted",
|
||||
)
|
||||
|
||||
ax.set_ylim(0, 0.6)
|
||||
ax.set_xticks(support)
|
||||
ax.legend()
|
||||
ax.grid()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_accuracy(accuracy):
|
||||
"""Create a barplot representing accuracy over different halting steps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
accuracy : np.array
|
||||
1D array representing accuracy if we were to take the output after
|
||||
the corresponding step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.Figure
|
||||
"""
|
||||
support = list(range(1, len(accuracy) + 1))
|
||||
|
||||
fig, ax = plt.subplots(dpi=140)
|
||||
|
||||
ax.bar(
|
||||
support,
|
||||
accuracy,
|
||||
label="Accuracy over different steps",
|
||||
)
|
||||
|
||||
ax.set_ylim(0, 1)
|
||||
ax.set_xticks(support)
|
||||
ax.legend()
|
||||
ax.grid()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
"""CLI for training."""
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"log_folder",
|
||||
type=str,
|
||||
help="Folder where tensorboard logging is saved",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Regularization loss coefficient",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=str,
|
||||
choices={"cpu", "cuda"},
|
||||
default="cpu",
|
||||
help="Device to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval-frequency",
|
||||
type=int,
|
||||
default=10_000,
|
||||
help="Evaluation is run every `eval_frequency` steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lambda-p",
|
||||
type=float,
|
||||
default=0.4,
|
||||
help="True probability of success for a geometric distribution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-iter",
|
||||
type=int,
|
||||
default=1_000_000,
|
||||
help="Number of gradient steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-elems",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of elements",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-hidden",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Number of hidden elements in the reccurent cell",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-nonzero",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=(None, None),
|
||||
help="Lower and upper bound on nonzero elements in the training set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-steps",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Maximum number of pondering steps",
|
||||
)
|
||||
|
||||
# Parameters
|
||||
args = parser.parse_args(argv)
|
||||
print(args)
|
||||
|
||||
device = torch.device(args.device)
|
||||
dtype = torch.float32
|
||||
n_eval_samples = 1000
|
||||
batch_size_eval = 50
|
||||
|
||||
if args.n_nonzero[0] is None and args.n_nonzero[1] is None:
|
||||
threshold = int(0.3 * args.n_elems)
|
||||
range_nonzero_easy = (1, threshold)
|
||||
range_nonzero_hard = (args.n_elems - threshold, args.n_elems)
|
||||
else:
|
||||
range_nonzero_easy = (1, args.n_nonzero[1])
|
||||
range_nonzero_hard = (args.n_nonzero[1] + 1, args.n_elems)
|
||||
|
||||
# Tensorboard
|
||||
log_folder = pathlib.Path(args.log_folder)
|
||||
writer = SummaryWriter(log_folder)
|
||||
writer.add_text("parameters", json.dumps(vars(args)))
|
||||
|
||||
# Prepare data
|
||||
dataloader_train = DataLoader(
|
||||
ParityDataset(
|
||||
n_samples=args.batch_size * args.n_iter,
|
||||
n_elems=args.n_elems,
|
||||
n_nonzero_min=args.n_nonzero[0],
|
||||
n_nonzero_max=args.n_nonzero[1],
|
||||
),
|
||||
batch_size=args.batch_size,
|
||||
) # consider specifying `num_workers` for speedups
|
||||
eval_dataloaders = {
|
||||
"test": DataLoader(
|
||||
ParityDataset(
|
||||
n_samples=n_eval_samples,
|
||||
n_elems=args.n_elems,
|
||||
n_nonzero_min=args.n_nonzero[0],
|
||||
n_nonzero_max=args.n_nonzero[1],
|
||||
),
|
||||
batch_size=batch_size_eval,
|
||||
),
|
||||
f"{range_nonzero_easy[0]}_{range_nonzero_easy[1]}": DataLoader(
|
||||
ParityDataset(
|
||||
n_samples=n_eval_samples,
|
||||
n_elems=args.n_elems,
|
||||
n_nonzero_min=range_nonzero_easy[0],
|
||||
n_nonzero_max=range_nonzero_easy[1],
|
||||
),
|
||||
batch_size=batch_size_eval,
|
||||
),
|
||||
f"{range_nonzero_hard[0]}_{range_nonzero_hard[1]}": DataLoader(
|
||||
ParityDataset(
|
||||
n_samples=n_eval_samples,
|
||||
n_elems=args.n_elems,
|
||||
n_nonzero_min=range_nonzero_hard[0],
|
||||
n_nonzero_max=range_nonzero_hard[1],
|
||||
),
|
||||
batch_size=batch_size_eval,
|
||||
),
|
||||
}
|
||||
|
||||
# Model preparation
|
||||
module = PetriDashNet(
|
||||
n_in=args.n_elems,
|
||||
n_neurons=args.n_hidden,
|
||||
n_out=1,
|
||||
max_steps=args.max_steps,
|
||||
)
|
||||
module = module.to(device, dtype)
|
||||
|
||||
# Loss preparation
|
||||
loss_rec_inst = ReconstructionLoss(
|
||||
nn.BCEWithLogitsLoss(reduction="none")
|
||||
).to(device, dtype)
|
||||
|
||||
loss_reg_inst = RegularizationLoss(
|
||||
lambda_p=args.lambda_p,
|
||||
max_steps=args.max_steps,
|
||||
).to(device, dtype)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.Adam(
|
||||
module.parameters(),
|
||||
lr=0.0003,
|
||||
)
|
||||
|
||||
# Training and evaluation loops
|
||||
iterator = tqdm(enumerate(dataloader_train), total=args.n_iter)
|
||||
for step, (x_batch, y_true_batch) in iterator:
|
||||
x_batch = x_batch.to(device, dtype)
|
||||
y_true_batch = y_true_batch.to(device, dtype)
|
||||
|
||||
y_pred_batch, p, halting_step = module(x_batch)
|
||||
|
||||
loss_rec = loss_rec_inst(
|
||||
p,
|
||||
y_pred_batch,
|
||||
y_true_batch,
|
||||
)
|
||||
|
||||
loss_reg = loss_reg_inst(
|
||||
p,
|
||||
)
|
||||
|
||||
loss_overall = loss_rec + args.beta * loss_reg
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss_overall.backward()
|
||||
torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
|
||||
optimizer.step()
|
||||
|
||||
# Logging
|
||||
writer.add_scalar("loss_rec", loss_rec, step)
|
||||
writer.add_scalar("loss_reg", loss_reg, step)
|
||||
writer.add_scalar("loss_overall", loss_overall, step)
|
||||
|
||||
# Evaluation
|
||||
if step % args.eval_frequency == 0:
|
||||
module.eval()
|
||||
|
||||
for dataloader_name, dataloader in eval_dataloaders.items():
|
||||
metrics_single, metrics_per_step = evaluate(
|
||||
dataloader,
|
||||
module,
|
||||
)
|
||||
fig_dist = plot_distributions(
|
||||
loss_reg_inst.p_g.cpu().numpy(),
|
||||
metrics_per_step["p"],
|
||||
)
|
||||
writer.add_figure(
|
||||
f"distributions/{dataloader_name}", fig_dist, step
|
||||
)
|
||||
|
||||
fig_acc = plot_accuracy(metrics_per_step["accuracy"])
|
||||
writer.add_figure(
|
||||
f"accuracy_per_step/{dataloader_name}", fig_acc, step
|
||||
)
|
||||
|
||||
for metric_name, metric_value in metrics_single.items():
|
||||
writer.add_scalar(
|
||||
f"{metric_name}/{dataloader_name}",
|
||||
metric_value,
|
||||
step,
|
||||
)
|
||||
|
||||
torch.save(module, log_folder / "checkpoint.pth")
|
||||
|
||||
module.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
289
utils.py
Normal file
289
utils.py
Normal file
@ -0,0 +1,289 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
from sparselinear import *
|
||||
|
||||
class ParityDataset(Dataset):
|
||||
"""Parity of vectors - binary classification dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_samples : int
|
||||
Number of samples to generate.
|
||||
|
||||
n_elems : int
|
||||
Size of the vectors.
|
||||
|
||||
n_nonzero_min, n_nonzero_max : int or None
|
||||
Minimum (inclusive) and maximum (inclusive) number of nonzero
|
||||
elements in the feature vector. If not specified then `(1, n_elem)`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_samples,
|
||||
n_elems,
|
||||
n_nonzero_min=None,
|
||||
n_nonzero_max=None,
|
||||
):
|
||||
self.n_samples = n_samples
|
||||
self.n_elems = n_elems
|
||||
|
||||
self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
|
||||
self.n_nonzero_max = (
|
||||
n_elems if n_nonzero_max is None else n_nonzero_max
|
||||
)
|
||||
|
||||
assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of samples."""
|
||||
return self.n_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get a feature vector and it's parity (target).
|
||||
|
||||
Note that the generating process is random.
|
||||
"""
|
||||
x = torch.zeros((self.n_elems,))
|
||||
n_non_zero = torch.randint(
|
||||
self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
|
||||
).item()
|
||||
x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
|
||||
x = x[torch.randperm(self.n_elems)]
|
||||
|
||||
y = (x == 1.0).sum() % 2
|
||||
|
||||
return x, y
|
||||
|
||||
|
||||
class PetriDishNet(nn.Module):
|
||||
"""Network based on the PetriDish-Architecture
|
||||
|
||||
Is capable of performing an uncontrolled intelligence-explosion.
|
||||
May or may not lead to the technological singularity when executed.
|
||||
Any similarity to any pop-culture AI, living or dead, or utopic or dystopic events, is purely coincidental.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_in : int
|
||||
Number of features in the input-vector.
|
||||
|
||||
n_out : int
|
||||
Number of features in the output-vector.
|
||||
|
||||
n_neurons : int
|
||||
Number of neurons.
|
||||
|
||||
max_steps : int
|
||||
Maximum number of steps the network can "ponder" for.
|
||||
|
||||
allow_halting : bool
|
||||
If True, then the forward pass is allowed to halt before
|
||||
reaching the maximum steps.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
synapses : SparseLinear
|
||||
Magic stuff
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert n_neurons >= n_in + n_out + 1, 'Insufficient number of neurons (min: '+str(n_in + n_out + 1)+')'
|
||||
|
||||
self.n_in = n_in
|
||||
self.n_out = n_out
|
||||
self.n_neurons = n_neurons
|
||||
self.max_steps = max_steps
|
||||
self.allow_halting = allow_halting
|
||||
|
||||
self.synapses = SparseLinear(n_neurons, n_neurons,
|
||||
sparsity=0.9,
|
||||
small_world=True,
|
||||
dynamic=True,
|
||||
deltaT=6000,
|
||||
Tend=150000,
|
||||
alpha=0.1,
|
||||
max_size=1e8)
|
||||
|
||||
self.ionDucts = ActivationSparsity(n_neurons,
|
||||
alpha=0.1,
|
||||
beta=1.5,
|
||||
act_sparsity=0.65)
|
||||
|
||||
def forward(self, x):
|
||||
"""Run forward pass.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Batch of input features of shape `(batch_size, n_elems)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y : torch.Tensor
|
||||
Tensor of shape `(max_steps, batch_size, n_out)` representing
|
||||
the predictions for each step and each sample. In case
|
||||
`allow_halting=True` then the shape is
|
||||
`(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.
|
||||
|
||||
p : torch.Tensor
|
||||
Tensor of shape `(max_steps, batch_size)` representing
|
||||
the halting probabilities. Sums over rows (fixing a sample)
|
||||
are 1. In case `allow_halting=True` then the shape is
|
||||
`(steps, batch_size)` where `1 <= steps <= max_steps`.
|
||||
|
||||
halting_step : torch.Tensor
|
||||
An integer for each sample in the batch that corresponds to
|
||||
the step when it was halted. The shape is `(batch_size,)`. The
|
||||
minimal value is 1 because we always run at least one step.
|
||||
"""
|
||||
batch_size, _ = x.shape
|
||||
device = x.device
|
||||
|
||||
state = torch.nn.functional.pad(input=x, pad=(0, self.n_neurons - self.n_in), mode='constant', value=0)
|
||||
|
||||
un_halted_prob = x.new_ones(batch_size)
|
||||
|
||||
y_list = []
|
||||
p_list = []
|
||||
|
||||
halting_step = torch.zeros(
|
||||
batch_size,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for n in range(1, self.max_steps + 1):
|
||||
state = self.synapses(state)
|
||||
if n == self.max_steps:
|
||||
lambda_n = x.new_ones(batch_size) # (batch_size,)
|
||||
else:
|
||||
lambda_n = torch.sigmoid(state[:, 0])
|
||||
|
||||
# Store releavant outputs
|
||||
y_list.append(state[:, (self.n_in+1):(self.n_in+self.n_out+1)])
|
||||
p_list.append(un_halted_prob * lambda_n)
|
||||
|
||||
halting_step = torch.maximum(
|
||||
n
|
||||
* (halting_step == 0)
|
||||
* torch.bernoulli(lambda_n).to(torch.long),
|
||||
halting_step,
|
||||
)
|
||||
|
||||
# Prepare for next iteration
|
||||
un_halted_prob = un_halted_prob * (1 - lambda_n)
|
||||
|
||||
# Potentially stop if all samples halted
|
||||
#if self.allow_halting and (halting_step > 0).sum() == batch_size:
|
||||
if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
|
||||
break
|
||||
|
||||
y = torch.stack(y_list)
|
||||
p = torch.stack(p_list)
|
||||
|
||||
return y, p, halting_step
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
"""Weighted average of per step losses.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
loss_func : callable
|
||||
Loss function that accepts `y_pred` and `y_true` as arguments. Both
|
||||
of these tensors have shape `(batch_size, n_out)`. It outputs a loss for
|
||||
each sample in the batch.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_func):
|
||||
super().__init__()
|
||||
|
||||
self.loss_func = loss_func
|
||||
|
||||
def forward(self, p, y_pred, y_true):
|
||||
"""Compute loss.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p : torch.Tensor
|
||||
Probability of halting of shape `(max_steps, batch_size)`.
|
||||
|
||||
y_pred : torch.Tensor
|
||||
Predicted outputs of shape `(max_steps, batch_size, n_out)`.
|
||||
|
||||
y_true : torch.Tensor
|
||||
True targets of shape `(batch_size, n_out)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
loss : torch.Tensor
|
||||
Scalar representing the reconstruction loss. It is nothing else
|
||||
than a weighted sum of per step losses.
|
||||
"""
|
||||
max_steps, _ = p.shape
|
||||
total_loss = p.new_tensor(0.0)
|
||||
|
||||
for n in range(max_steps):
|
||||
loss_per_sample = p[n] * self.loss_func(
|
||||
y_pred[n], y_true
|
||||
) # (batch_size,)
|
||||
total_loss = total_loss + loss_per_sample.mean() # (1,)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class RegularizationLoss(nn.Module):
|
||||
"""Enforce halting distribution to ressemble the geometric distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lambda_p : float
|
||||
The single parameter determining uniquely the geometric distribution.
|
||||
Note that the expected value of this distribution is going to be
|
||||
`1 / lambda_p`.
|
||||
|
||||
max_steps : int
|
||||
Maximum number of pondering steps.
|
||||
"""
|
||||
|
||||
def __init__(self, lambda_p, max_steps=20):
|
||||
super().__init__()
|
||||
|
||||
p_g = torch.zeros((max_steps,))
|
||||
not_halted = 1.0
|
||||
|
||||
for k in range(max_steps):
|
||||
p_g[k] = not_halted * lambda_p
|
||||
not_halted = not_halted * (1 - lambda_p)
|
||||
|
||||
self.register_buffer("p_g", p_g)
|
||||
self.kl_div = nn.KLDivLoss(reduction="batchmean")
|
||||
|
||||
def forward(self, p):
|
||||
"""Compute loss.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
p : torch.Tensor
|
||||
Probability of halting of shape `(steps, batch_size)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
loss : torch.Tensor
|
||||
Scalar representing the regularization loss.
|
||||
"""
|
||||
steps, batch_size = p.shape
|
||||
|
||||
p = p.transpose(0, 1) # (batch_size, max_steps)
|
||||
|
||||
p_g_batch = self.p_g[None, :steps].expand_as(
|
||||
p
|
||||
) # (batch_size, max_steps)
|
||||
|
||||
return self.kl_div(p.log(), p_g_batch)
|
Loading…
Reference in New Issue
Block a user