# petriDish/train.py
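"""Train a PonderNet-style network (PetriDishNet) on the parity task.

Logs losses, halting distributions, and per-step accuracies to TensorBoard
and periodically saves a checkpoint. Example invocation (the log folder
name is illustrative):

    python train.py runs/parity

"""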
from argparse import ArgumentParser
import json
import pathlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import (
    ParityDataset,
    PetriDishNet,
    ReconstructionLoss,
    RegularizationLoss,
)


@torch.no_grad()
def evaluate(dataloader, module):
    """Compute relevant metrics.

    Parameters
    ----------
    dataloader : DataLoader
        Dataloader that yields batches of `x` and `y`.

    module : PetriDishNet
        Our pondering network.

    Returns
    -------
    metrics_single : dict
        Scalar metrics. The keys are names and the values are `torch.Tensor`.
        These metrics are computed as mean values over the entire dataset.

    metrics_per_step : dict
        Per step metrics. The keys are names and the values are `torch.Tensor`
        of shape `(max_steps,)`. These metrics are computed as mean values over
        the entire dataset.
    """
    # Infer device and dtype from the module's parameters
    param = next(module.parameters())
    device, dtype = param.device, param.dtype

    metrics_single_ = {
        "accuracy_halted": [],
        "halting_step": [],
    }
    metrics_per_step_ = {
        "accuracy": [],
        "p": [],
    }

    for x_batch, y_true_batch in dataloader:
        x_batch = x_batch.to(device, dtype)  # (batch_size, n_elems)
        y_true_batch = y_true_batch.to(device, dtype)  # (batch_size,)

        y_pred_batch, p, halting_step = module(x_batch)
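        # Shapes, as implied by the indexing below: `y_pred_batch` and `p`
        # are `(max_steps, batch_size)` and `halting_step` is `(batch_size,)`.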
        y_halted_batch = y_pred_batch.gather(
            dim=0,
            index=halting_step[None, :] - 1,
        )[0]  # (batch_size,)
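        # The gather above picks, per sample, the prediction made at its own
        # halting step; `- 1` converts the 1-indexed step to a 0-based index.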
        # Computing single metrics (mean over samples in the batch)
        accuracy_halted = (
            ((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
        )
        metrics_single_["accuracy_halted"].append(accuracy_halted)
        metrics_single_["halting_step"].append(
            halting_step.to(torch.float).mean()
        )

        # Computing per step metrics (mean over samples in the batch)
        accuracy = (
            ((y_pred_batch > 0) == y_true_batch[None, :])
            .to(torch.float32)
            .mean(dim=1)
        )
        metrics_per_step_["accuracy"].append(accuracy)
        metrics_per_step_["p"].append(p.mean(dim=1))

    metrics_single = {
        name: torch.stack(values).mean(dim=0).cpu().numpy()
        for name, values in metrics_single_.items()
    }
    metrics_per_step = {
        name: torch.stack(values).mean(dim=0).cpu().numpy()
        for name, values in metrics_per_step_.items()
    }

    return metrics_single, metrics_per_step


def plot_distributions(target, predicted):
    """Create a barplot.

    Parameters
    ----------
    target, predicted : np.ndarray
        Arrays of shape `(max_steps,)` representing the target and predicted
        probability distributions.

    Returns
    -------
    matplotlib.Figure
    """
    support = list(range(1, len(target) + 1))

    fig, ax = plt.subplots(dpi=140)
    ax.bar(
        support,
        target,
        color="red",
        label=f"Target - Geometric({target[0].item():.2f})",
    )
    ax.bar(
        support,
        predicted,
        color="green",
        width=0.4,
        label="Predicted",
    )
    ax.set_ylim(0, 0.6)
    ax.set_xticks(support)
    ax.legend()
    ax.grid()

    return fig


def plot_accuracy(accuracy):
    """Create a barplot representing accuracy over different halting steps.

    Parameters
    ----------
    accuracy : np.ndarray
        1D array representing accuracy if we were to take the output after
        the corresponding step.

    Returns
    -------
    matplotlib.Figure
    """
    support = list(range(1, len(accuracy) + 1))

    fig, ax = plt.subplots(dpi=140)
    ax.bar(
        support,
        accuracy,
        label="Accuracy over different steps",
    )
    ax.set_ylim(0, 1)
    ax.set_xticks(support)
    ax.legend()
    ax.grid()

    return fig


def main(argv=None):
    """CLI for training."""
    parser = ArgumentParser()
    parser.add_argument(
        "log_folder",
        type=str,
        help="Folder where tensorboard logging is saved",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=0.01,
        help="Regularization loss coefficient",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        choices={"cpu", "cuda"},
        default="cpu",
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=10_000,
        help="Evaluation is run every `eval_frequency` steps",
    )
    parser.add_argument(
        "--lambda-p",
        type=float,
        default=0.4,
        help="True probability of success for a geometric distribution",
    )
    parser.add_argument(
        "--n-iter",
        type=int,
        default=1_000_000,
        help="Number of gradient steps",
    )
    parser.add_argument(
        "--n-elems",
        type=int,
        default=64,
        help="Number of elements in the input vector",
    )
    parser.add_argument(
        "--n-hidden",
        type=int,
        default=64,
        help="Number of hidden elements in the recurrent cell",
    )
    parser.add_argument(
        "--n-nonzero",
        type=int,
        nargs=2,
        default=(None, None),
        help="Lower and upper bound on nonzero elements in the training set",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=20,
        help="Maximum number of pondering steps",
    )

    # Parameters
    args = parser.parse_args(argv)
    print(args)

    device = torch.device(args.device)
    dtype = torch.float32
    n_eval_samples = 1000
    batch_size_eval = 50

    if args.n_nonzero[0] is None and args.n_nonzero[1] is None:
        threshold = int(0.3 * args.n_elems)
        range_nonzero_easy = (1, threshold)
        range_nonzero_hard = (args.n_elems - threshold, args.n_elems)
    else:
        range_nonzero_easy = (1, args.n_nonzero[1])
        range_nonzero_hard = (args.n_nonzero[1] + 1, args.n_elems)
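    # Evaluation splits: "easy" vectors have few nonzero elements, "hard"
    # ones have many. When `--n-nonzero` bounds the training set, the hard
    # range lies entirely outside it, so it probes extrapolation.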

    # Tensorboard
    log_folder = pathlib.Path(args.log_folder)
    writer = SummaryWriter(log_folder)
    writer.add_text("parameters", json.dumps(vars(args)))

    # Prepare data
    dataloader_train = DataLoader(
        ParityDataset(
            n_samples=args.batch_size * args.n_iter,
            n_elems=args.n_elems,
            n_nonzero_min=args.n_nonzero[0],
            n_nonzero_max=args.n_nonzero[1],
        ),
        batch_size=args.batch_size,
    )  # consider specifying `num_workers` for speedups
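    # `n_samples = batch_size * n_iter`, so a single pass over this loader
    # yields exactly `n_iter` batches: the loop below is one epoch long.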
    eval_dataloaders = {
        "test": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=args.n_nonzero[0],
                n_nonzero_max=args.n_nonzero[1],
            ),
            batch_size=batch_size_eval,
        ),
        f"{range_nonzero_easy[0]}_{range_nonzero_easy[1]}": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=range_nonzero_easy[0],
                n_nonzero_max=range_nonzero_easy[1],
            ),
            batch_size=batch_size_eval,
        ),
        f"{range_nonzero_hard[0]}_{range_nonzero_hard[1]}": DataLoader(
            ParityDataset(
                n_samples=n_eval_samples,
                n_elems=args.n_elems,
                n_nonzero_min=range_nonzero_hard[0],
                n_nonzero_max=range_nonzero_hard[1],
            ),
            batch_size=batch_size_eval,
        ),
    }

    # Model preparation
    module = PetriDishNet(
        n_in=args.n_elems,
        n_neurons=args.n_hidden,
        n_out=1,
        max_steps=args.max_steps,
    )
    module = module.to(device, dtype)
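    # A single output logit per step (`n_out=1`): labels are decoded as
    # `y_pred > 0`, matching the BCEWithLogitsLoss used below.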

    # Loss preparation
    loss_rec_inst = ReconstructionLoss(
        nn.BCEWithLogitsLoss(reduction="none")
    ).to(device, dtype)
    loss_reg_inst = RegularizationLoss(
        lambda_p=args.lambda_p,
        max_steps=args.max_steps,
    ).to(device, dtype)
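    # Presumably (as in the PonderNet paper) `RegularizationLoss` is a KL
    # divergence between the halting distribution `p` and a geometric prior
    # with success probability `lambda_p`; its pmf is exposed as `p_g` and
    # plotted against the predicted distribution during evaluation.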

    # Optimizer
    optimizer = torch.optim.Adam(
        module.parameters(),
        lr=0.0003,
    )

    # Training and evaluation loops
    iterator = tqdm(enumerate(dataloader_train), total=args.n_iter)

    for step, (x_batch, y_true_batch) in iterator:
        x_batch = x_batch.to(device, dtype)
        y_true_batch = y_true_batch.to(device, dtype)

        y_pred_batch, p, halting_step = module(x_batch)

        loss_rec = loss_rec_inst(
            p,
            y_pred_batch,
            y_true_batch,
        )
        loss_reg = loss_reg_inst(
            p,
        )
        loss_overall = loss_rec + args.beta * loss_reg
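        # Overall objective: reconstruction loss (the per-step losses
        # weighted by the halting distribution `p`) plus `beta` times the
        # regularizer pulling `p` towards the geometric prior.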
        optimizer.zero_grad()
        loss_overall.backward()
        torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
        optimizer.step()

        # Logging
        writer.add_scalar("loss_rec", loss_rec, step)
        writer.add_scalar("loss_reg", loss_reg, step)
        writer.add_scalar("loss_overall", loss_overall, step)

        # Evaluation
        if step % args.eval_frequency == 0:
            module.eval()

            for dataloader_name, dataloader in eval_dataloaders.items():
                metrics_single, metrics_per_step = evaluate(
                    dataloader,
                    module,
                )
                fig_dist = plot_distributions(
                    loss_reg_inst.p_g.cpu().numpy(),
                    metrics_per_step["p"],
                )
                writer.add_figure(
                    f"distributions/{dataloader_name}", fig_dist, step
                )

                fig_acc = plot_accuracy(metrics_per_step["accuracy"])
                writer.add_figure(
                    f"accuracy_per_step/{dataloader_name}", fig_acc, step
                )

                for metric_name, metric_value in metrics_single.items():
                    writer.add_scalar(
                        f"{metric_name}/{dataloader_name}",
                        metric_value,
                        step,
                    )
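            # Note: this pickles the entire module (code + weights); saving
            # `module.state_dict()` instead would be more portable.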
            torch.save(module, log_folder / "checkpoint.pth")
            module.train()


if __name__ == "__main__":
    main()