commit 63cd063c7084b35bb120e0283765c6cbcd7a0c0a
Author: Dominik Roth
Date:   Thu Oct 14 22:09:14 2021 +0200

    initial commit

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4f14281
--- /dev/null
+++ b/README.md
@@ -0,0 +1,12 @@
+# Project PetriDish
+Quick and dirty PoC for the idea behind Project Neuromorph.
+Combines PonderNet and SparseLinear.
+PonderNet stolen from https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/pondernet
+SparseLinear stolen from https://pypi.org/project/sparselinear/
+
+## Architecture
+We build a neural network composed of a set of neurons that are connected by a set of synapses. Neurons that are close to each other (we use a 1D distance metric; Project Neuromorph will use an approximate Euclidean distance) have a higher chance of sharing a synaptic connection.
+We train this net as usual, but we also allow the structure of the synaptic connections to change during training. (The number of neurons remains constant; in Neuromorph this will also be variable, and Neuromorph will use more advanced algorithms to decide where to spawn new synapses.)
+In every firing cycle only a fraction of all neurons (those with the highest output) are allowed to fire; all others are inhibited. (In Project Neuromorph this will be less strict: neurons with low firing rates will have higher dropout chances, and we discourage firing through an additional loss.)
+Based on the PonderNet architecture, we allow our network to 'think' about a given problem for as long as it wants (well, ok; there is a maximum number of firing cycles to make training possible).
+The inputs and outputs of the network have a fixed length. (Project Neuromorph will also allow variable-length outputs, like an RNN.)
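The distance-biased wiring and the top-k firing described above are not spelled out anywhere in this commit (the PoC delegates them to the SparseLinear package via `small_world=True` and `ActivationSparsity`), so here is a minimal, self-contained sketch of the two ideas under the 1D-distance assumption. All names in it (`connection_mask`, `top_k_fire`, `decay`, `k`) are illustrative and not part of this repository.

import torch

def connection_mask(n_neurons, decay=0.3, generator=None):
    """Sample a synapse mask where nearby neurons (1D index distance)
    are more likely to be connected: P(i<->j) = exp(-decay * |i - j|)."""
    idx = torch.arange(n_neurons)
    dist = (idx[None, :] - idx[:, None]).abs().float()
    prob = torch.exp(-decay * dist)
    return torch.bernoulli(prob, generator=generator).bool()

def top_k_fire(activations, k):
    """Let only the k neurons with the highest output fire; inhibit the rest."""
    topk = activations.topk(k, dim=-1)
    mask = torch.zeros_like(activations).scatter(-1, topk.indices, 1.0)
    return activations * mask

# toy usage
acts = torch.randn(2, 16)    # (batch, n_neurons)
mask = connection_mask(16)   # (n_neurons, n_neurons) adjacency sample
fired = top_k_fire(acts, k=4)  # only 4 neurons per sample stay active

The actual PoC below lets SparseLinear handle the sparse, distance-biased connectivity and its rewiring during training.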
diff --git a/__pycache__/utils.cpython-39.pyc b/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000..203978f
Binary files /dev/null and b/__pycache__/utils.cpython-39.pyc differ
diff --git a/results/experiment_b/31742/events.out.tfevents.1634241559.dodox-XPS-15.224268.0 b/results/experiment_b/31742/events.out.tfevents.1634241559.dodox-XPS-15.224268.0
new file mode 100644
index 0000000..9dd86e9
Binary files /dev/null and b/results/experiment_b/31742/events.out.tfevents.1634241559.dodox-XPS-15.224268.0 differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..5c41e90
--- /dev/null
+++ b/train.py
@@ -0,0 +1,398 @@
+from argparse import ArgumentParser
+import json
+import pathlib
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from utils import (
+    ParityDataset,
+    PetriDishNet,
+    ReconstructionLoss,
+    RegularizationLoss,
+)
+
+
+@torch.no_grad()
+def evaluate(dataloader, module):
+    """Compute relevant metrics.
+
+    Parameters
+    ----------
+    dataloader : DataLoader
+        Dataloader that yields batches of `x` and `y`.
+
+    module : PetriDishNet
+        Our pondering network.
+
+    Returns
+    -------
+    metrics_single : dict
+        Scalar metrics. The keys are names and the values are `torch.Tensor`.
+        These metrics are computed as mean values over the entire dataset.
+
+    metrics_per_step : dict
+        Per step metrics. The keys are names and the values are `torch.Tensor`
+        of shape `(max_steps,)`. These metrics are computed as mean values over
+        the entire dataset.
+    """
+    # Imply device and dtype
+    param = next(module.parameters())
+    device, dtype = param.device, param.dtype
+
+    metrics_single_ = {
+        "accuracy_halted": [],
+        "halting_step": [],
+    }
+    metrics_per_step_ = {
+        "accuracy": [],
+        "p": [],
+    }
+
+    for x_batch, y_true_batch in dataloader:
+        x_batch = x_batch.to(device, dtype)  # (batch_size, n_elems)
+        y_true_batch = y_true_batch.to(device, dtype)  # (batch_size,)
+
+        y_pred_batch, p, halting_step = module(x_batch)
+        y_pred_batch = y_pred_batch.squeeze(-1)  # (max_steps, batch_size) since n_out == 1
+        y_halted_batch = y_pred_batch.gather(
+            dim=0,
+            index=halting_step[None, :] - 1,
+        )[
+            0
+        ]  # (batch_size,)
+
+        # Computing single metrics (mean over samples in the batch)
+        accuracy_halted = (
+            ((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
+        )
+
+        metrics_single_["accuracy_halted"].append(accuracy_halted)
+        metrics_single_["halting_step"].append(
+            halting_step.to(torch.float).mean()
+        )
+
+        # Computing per step metrics (mean over samples in the batch)
+        accuracy = (
+            ((y_pred_batch > 0) == y_true_batch[None, :])
+            .to(torch.float32)
+            .mean(dim=1)
+        )
+
+        metrics_per_step_["accuracy"].append(accuracy)
+        metrics_per_step_["p"].append(p.mean(dim=1))
+
+    metrics_single = {
+        name: torch.stack(values).mean(dim=0).cpu().numpy()
+        for name, values in metrics_single_.items()
+    }
+
+    metrics_per_step = {
+        name: torch.stack(values).mean(dim=0).cpu().numpy()
+        for name, values in metrics_per_step_.items()
+    }
+
+    return metrics_single, metrics_per_step
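The `gather` call above is the one non-obvious line in `evaluate`: it picks, for every sample, the prediction made at that sample's halting step (halting steps are 1-based, hence the `- 1`). A small illustrative loop, purely for explanation and not used by the script, that computes the same thing:

import torch

def pick_halted_predictions(y_pred, halting_step):
    """Equivalent to y_pred.gather(dim=0, index=halting_step[None, :] - 1)[0]
    for y_pred of shape (max_steps, batch_size) and 1-based halting steps."""
    batch_size = y_pred.shape[1]
    return torch.stack(
        [y_pred[halting_step[b] - 1, b] for b in range(batch_size)]
    )

y_pred = torch.tensor([[0.1, 0.2],   # step 1
                       [0.3, 0.4],   # step 2
                       [0.5, 0.6]])  # step 3
halting_step = torch.tensor([2, 3])  # sample 0 halts at step 2, sample 1 at step 3
print(pick_halted_predictions(y_pred, halting_step))  # tensor([0.3000, 0.6000])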
+
+
+def plot_distributions(target, predicted):
+    """Create a barplot.
+
+    Parameters
+    ----------
+    target, predicted : np.ndarray
+        Arrays of shape `(max_steps,)` representing the target and predicted
+        probability distributions.
+
+    Returns
+    -------
+    matplotlib.Figure
+    """
+    support = list(range(1, len(target) + 1))
+
+    fig, ax = plt.subplots(dpi=140)
+
+    ax.bar(
+        support,
+        target,
+        color="red",
+        label=f"Target - Geometric({target[0].item():.2f})",
+    )
+
+    ax.bar(
+        support,
+        predicted,
+        color="green",
+        width=0.4,
+        label="Predicted",
+    )
+
+    ax.set_ylim(0, 0.6)
+    ax.set_xticks(support)
+    ax.legend()
+    ax.grid()
+
+    return fig
+
+
+def plot_accuracy(accuracy):
+    """Create a barplot representing accuracy over different halting steps.
+
+    Parameters
+    ----------
+    accuracy : np.ndarray
+        1D array representing accuracy if we were to take the output after
+        the corresponding step.
+
+    Returns
+    -------
+    matplotlib.Figure
+    """
+    support = list(range(1, len(accuracy) + 1))
+
+    fig, ax = plt.subplots(dpi=140)
+
+    ax.bar(
+        support,
+        accuracy,
+        label="Accuracy over different steps",
+    )
+
+    ax.set_ylim(0, 1)
+    ax.set_xticks(support)
+    ax.legend()
+    ax.grid()
+
+    return fig
+
+
+def main(argv=None):
+    """CLI for training."""
+    parser = ArgumentParser()
+
+    parser.add_argument(
+        "log_folder",
+        type=str,
+        help="Folder where tensorboard logging is saved",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=128,
+        help="Batch size",
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        default=0.01,
+        help="Regularization loss coefficient",
+    )
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        choices={"cpu", "cuda"},
+        default="cpu",
+        help="Device to use",
+    )
+    parser.add_argument(
+        "--eval-frequency",
+        type=int,
+        default=10_000,
+        help="Evaluation is run every `eval_frequency` steps",
+    )
+    parser.add_argument(
+        "--lambda-p",
+        type=float,
+        default=0.4,
+        help="True probability of success for a geometric distribution",
+    )
+    parser.add_argument(
+        "--n-iter",
+        type=int,
+        default=1_000_000,
+        help="Number of gradient steps",
+    )
+    parser.add_argument(
+        "--n-elems",
+        type=int,
+        default=64,
+        help="Number of elements",
+    )
+    parser.add_argument(
+        "--n-hidden",
+        type=int,
+        default=64,
+        help="Number of hidden elements in the recurrent cell",
+    )
+    parser.add_argument(
+        "--n-nonzero",
+        type=int,
+        nargs=2,
+        default=(None, None),
+        help="Lower and upper bound on nonzero elements in the training set",
+    )
+    parser.add_argument(
+        "--max-steps",
+        type=int,
+        default=20,
+        help="Maximum number of pondering steps",
+    )
+
+    # Parameters
+    args = parser.parse_args(argv)
+    print(args)
+
+    device = torch.device(args.device)
+    dtype = torch.float32
+    n_eval_samples = 1000
+    batch_size_eval = 50
+
+    if args.n_nonzero[0] is None and args.n_nonzero[1] is None:
+        threshold = int(0.3 * args.n_elems)
+        range_nonzero_easy = (1, threshold)
+        range_nonzero_hard = (args.n_elems - threshold, args.n_elems)
+    else:
+        range_nonzero_easy = (1, args.n_nonzero[1])
+        range_nonzero_hard = (args.n_nonzero[1] + 1, args.n_elems)
+
+    # Tensorboard
+    log_folder = pathlib.Path(args.log_folder)
+    writer = SummaryWriter(log_folder)
+    writer.add_text("parameters", json.dumps(vars(args)))
+
+    # Prepare data
+    dataloader_train = DataLoader(
+        ParityDataset(
+            n_samples=args.batch_size * args.n_iter,
+            n_elems=args.n_elems,
+            n_nonzero_min=args.n_nonzero[0],
+            n_nonzero_max=args.n_nonzero[1],
+        ),
+        batch_size=args.batch_size,
+    )  # consider specifying `num_workers` for speedups
+    eval_dataloaders = {
+        "test": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=args.n_nonzero[0],
+                n_nonzero_max=args.n_nonzero[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+        f"{range_nonzero_easy[0]}_{range_nonzero_easy[1]}": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=range_nonzero_easy[0],
+                n_nonzero_max=range_nonzero_easy[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+        f"{range_nonzero_hard[0]}_{range_nonzero_hard[1]}": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=range_nonzero_hard[0],
+                n_nonzero_max=range_nonzero_hard[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+    }
+
+    # Model preparation
+    module = PetriDishNet(
+        n_in=args.n_elems,
+        n_neurons=args.n_hidden,
+        n_out=1,
+        max_steps=args.max_steps,
+    )
+    module = module.to(device, dtype)
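For reference, the parity task served by the dataloaders above is easy to inspect by hand. A quick sketch (the seed and printed values are illustrative only, since `ParityDataset.__getitem__` samples randomly on every access):

import torch
from utils import ParityDataset

torch.manual_seed(0)  # illustrative; samples are random, so your values will differ
ds = ParityDataset(n_samples=4, n_elems=8)
x, y = ds[0]
# x is an 8-element vector with entries in {-1, 0, 1};
# y is 1 if the number of +1 entries is odd, else 0.
print(x, y, (x == 1.0).sum() % 2 == y)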
+
+    # Loss preparation
+    loss_rec_inst = ReconstructionLoss(
+        nn.BCEWithLogitsLoss(reduction="none")
+    ).to(device, dtype)
+
+    loss_reg_inst = RegularizationLoss(
+        lambda_p=args.lambda_p,
+        max_steps=args.max_steps,
+    ).to(device, dtype)
+
+    # Optimizer
+    optimizer = torch.optim.Adam(
+        module.parameters(),
+        lr=0.0003,
+    )
+
+    # Training and evaluation loops
+    iterator = tqdm(enumerate(dataloader_train), total=args.n_iter)
+    for step, (x_batch, y_true_batch) in iterator:
+        x_batch = x_batch.to(device, dtype)
+        y_true_batch = y_true_batch.to(device, dtype)
+
+        y_pred_batch, p, halting_step = module(x_batch)
+        y_pred_batch = y_pred_batch.squeeze(-1)  # (max_steps, batch_size) since n_out == 1
+
+        loss_rec = loss_rec_inst(
+            p,
+            y_pred_batch,
+            y_true_batch,
+        )
+
+        loss_reg = loss_reg_inst(
+            p,
+        )
+
+        loss_overall = loss_rec + args.beta * loss_reg
+
+        optimizer.zero_grad()
+        loss_overall.backward()
+        torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
+        optimizer.step()
+
+        # Logging
+        writer.add_scalar("loss_rec", loss_rec, step)
+        writer.add_scalar("loss_reg", loss_reg, step)
+        writer.add_scalar("loss_overall", loss_overall, step)
+
+        # Evaluation
+        if step % args.eval_frequency == 0:
+            module.eval()
+
+            for dataloader_name, dataloader in eval_dataloaders.items():
+                metrics_single, metrics_per_step = evaluate(
+                    dataloader,
+                    module,
+                )
+                fig_dist = plot_distributions(
+                    loss_reg_inst.p_g.cpu().numpy(),
+                    metrics_per_step["p"],
+                )
+                writer.add_figure(
+                    f"distributions/{dataloader_name}", fig_dist, step
+                )
+
+                fig_acc = plot_accuracy(metrics_per_step["accuracy"])
+                writer.add_figure(
+                    f"accuracy_per_step/{dataloader_name}", fig_acc, step
+                )
+
+                for metric_name, metric_value in metrics_single.items():
+                    writer.add_scalar(
+                        f"{metric_name}/{dataloader_name}",
+                        metric_value,
+                        step,
+                    )
+
+            torch.save(module, log_folder / "checkpoint.pth")
+
+            module.train()
+
+
+if __name__ == "__main__":
+    main()
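A minimal way to exercise the whole script end to end, as a sketch rather than a documented workflow: the log folder name and the reduced iteration count are placeholder choices, and `--n-hidden` is raised here because `PetriDishNet` asserts `n_neurons >= n_in + n_out + 1`.

# Smoke-test sketch; "results/smoke_test" is a placeholder path.
from train import main

main([
    "results/smoke_test",      # log_folder (positional)
    "--n-iter", "1000",        # far fewer gradient steps than the default
    "--eval-frequency", "500",
    "--n-hidden", "128",       # must be >= n_elems + n_out + 1 (see utils.PetriDishNet)
    "--device", "cpu",
])
# Progress can then be inspected with TensorBoard pointed at the log folder.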
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..e4fdf41
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,289 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+from sparselinear import SparseLinear, ActivationSparsity
+
+
+class ParityDataset(Dataset):
+    """Parity of vectors - binary classification dataset.
+
+    Parameters
+    ----------
+    n_samples : int
+        Number of samples to generate.
+
+    n_elems : int
+        Size of the vectors.
+
+    n_nonzero_min, n_nonzero_max : int or None
+        Minimum (inclusive) and maximum (inclusive) number of nonzero
+        elements in the feature vector. If not specified then `(1, n_elems)`.
+    """
+
+    def __init__(
+        self,
+        n_samples,
+        n_elems,
+        n_nonzero_min=None,
+        n_nonzero_max=None,
+    ):
+        self.n_samples = n_samples
+        self.n_elems = n_elems
+
+        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
+        self.n_nonzero_max = (
+            n_elems if n_nonzero_max is None else n_nonzero_max
+        )
+
+        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems
+
+    def __len__(self):
+        """Get the number of samples."""
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        """Get a feature vector and its parity (target).
+
+        Note that the generating process is random.
+        """
+        x = torch.zeros((self.n_elems,))
+        n_non_zero = torch.randint(
+            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
+        ).item()
+        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
+        x = x[torch.randperm(self.n_elems)]
+
+        y = (x == 1.0).sum() % 2
+
+        return x, y
+
+
+class PetriDishNet(nn.Module):
+    """Network based on the PetriDish architecture.
+
+    Is capable of performing an uncontrolled intelligence explosion.
+    May or may not lead to the technological singularity when executed.
+    Any similarity to any pop-culture AI, living or dead, or to utopic or
+    dystopic events, is purely coincidental.
+
+    Parameters
+    ----------
+    n_in : int
+        Number of features in the input vector.
+
+    n_out : int
+        Number of features in the output vector.
+
+    n_neurons : int
+        Number of neurons.
+
+    max_steps : int
+        Maximum number of steps the network can "ponder" for.
+
+    allow_halting : bool
+        If True, then the forward pass is allowed to halt before
+        reaching the maximum steps.
+
+    Attributes
+    ----------
+    synapses : SparseLinear
+        Sparse, dynamically rewired small-world connections between the
+        neurons (a.k.a. magic stuff).
+
+    ionDucts : ActivationSparsity
+        Activation-sparsity layer meant to let only a fraction of neurons
+        fire (not yet applied in `forward`).
+    """
+
+    def __init__(
+        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
+    ):
+        super().__init__()
+
+        assert n_neurons >= n_in + n_out + 1, (
+            'Insufficient number of neurons (min: ' + str(n_in + n_out + 1) + ')'
+        )
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.n_neurons = n_neurons
+        self.max_steps = max_steps
+        self.allow_halting = allow_halting
+
+        self.synapses = SparseLinear(n_neurons, n_neurons,
+                                     sparsity=0.9,
+                                     small_world=True,
+                                     dynamic=True,
+                                     deltaT=6000,
+                                     Tend=150000,
+                                     alpha=0.1,
+                                     max_size=1e8)
+
+        self.ionDucts = ActivationSparsity(n_neurons,
+                                           alpha=0.1,
+                                           beta=1.5,
+                                           act_sparsity=0.65)
+
+    def forward(self, x):
+        """Run forward pass.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Batch of input features of shape `(batch_size, n_elems)`.
+
+        Returns
+        -------
+        y : torch.Tensor
+            Tensor of shape `(max_steps, batch_size, n_out)` representing
+            the predictions for each step and each sample. In case
+            `allow_halting=True` then the shape is
+            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.
+
+        p : torch.Tensor
+            Tensor of shape `(max_steps, batch_size)` representing
+            the halting probabilities. Sums over rows (fixing a sample)
+            are 1. In case `allow_halting=True` then the shape is
+            `(steps, batch_size)` where `1 <= steps <= max_steps`.
+
+        halting_step : torch.Tensor
+            An integer for each sample in the batch that corresponds to
+            the step when it was halted. The shape is `(batch_size,)`. The
+            minimal value is 1 because we always run at least one step.
+        """
+        batch_size, _ = x.shape
+        device = x.device
+
+        state = torch.nn.functional.pad(
+            input=x, pad=(0, self.n_neurons - self.n_in), mode='constant', value=0
+        )
+
+        un_halted_prob = x.new_ones(batch_size)
+
+        y_list = []
+        p_list = []
+
+        halting_step = torch.zeros(
+            batch_size,
+            dtype=torch.long,
+            device=device,
+        )
+
+        for n in range(1, self.max_steps + 1):
+            state = self.synapses(state)
+            if n == self.max_steps:
+                lambda_n = x.new_ones(batch_size)  # (batch_size,)
+            else:
+                lambda_n = torch.sigmoid(state[:, 0])
+
+            # Store relevant outputs
+            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
+            p_list.append(un_halted_prob * lambda_n)
+
+            halting_step = torch.maximum(
+                n
+                * (halting_step == 0)
+                * torch.bernoulli(lambda_n).to(torch.long),
+                halting_step,
+            )
+
+            # Prepare for next iteration
+            un_halted_prob = un_halted_prob * (1 - lambda_n)
+
+            # Potentially stop if all samples halted
+            # if self.allow_halting and (halting_step > 0).sum() == batch_size:
+            if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
+                break
+
+        y = torch.stack(y_list)
+        p = torch.stack(p_list)
+
+        return y, p, halting_step
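The `p_list` bookkeeping above is the PonderNet trick: the probability of halting exactly at step n is the step's halting odds times the probability of not having halted earlier, and forcing the odds to 1 at `max_steps` makes the per-sample probabilities sum to 1. A tiny standalone check (the per-step values below are made up for illustration):

import torch

lambdas = torch.tensor([0.2, 0.5, 0.3, 1.0])  # made-up halting odds; last step forced to 1

un_halted = torch.tensor(1.0)
p = []
for lam in lambdas:
    p.append(un_halted * lam)        # probability of halting exactly at this step
    un_halted = un_halted * (1 - lam)

p = torch.stack(p)
print(p)        # roughly [0.20, 0.40, 0.12, 0.28]
print(p.sum())  # 1, up to float rounding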
+
+
+class ReconstructionLoss(nn.Module):
+    """Weighted average of per step losses.
+
+    Parameters
+    ----------
+    loss_func : callable
+        Loss function that accepts `y_pred` and `y_true` as arguments. Both
+        of these tensors have shape `(batch_size,)`. It outputs a loss for
+        each sample in the batch.
+    """
+
+    def __init__(self, loss_func):
+        super().__init__()
+
+        self.loss_func = loss_func
+
+    def forward(self, p, y_pred, y_true):
+        """Compute loss.
+
+        Parameters
+        ----------
+        p : torch.Tensor
+            Probability of halting of shape `(max_steps, batch_size)`.
+
+        y_pred : torch.Tensor
+            Predicted outputs of shape `(max_steps, batch_size)`.
+
+        y_true : torch.Tensor
+            True targets of shape `(batch_size,)`.
+
+        Returns
+        -------
+        loss : torch.Tensor
+            Scalar representing the reconstruction loss. It is nothing else
+            than a weighted sum of per step losses.
+        """
+        max_steps, _ = p.shape
+        total_loss = p.new_tensor(0.0)
+
+        for n in range(max_steps):
+            loss_per_sample = p[n] * self.loss_func(
+                y_pred[n], y_true
+            )  # (batch_size,)
+            total_loss = total_loss + loss_per_sample.mean()  # (1,)
+
+        return total_loss
+
+
+class RegularizationLoss(nn.Module):
+    """Enforce the halting distribution to resemble a geometric distribution.
+
+    Parameters
+    ----------
+    lambda_p : float
+        The single parameter uniquely determining the geometric distribution.
+        Note that the expected value of this distribution is going to be
+        `1 / lambda_p`.
+
+    max_steps : int
+        Maximum number of pondering steps.
+    """
+
+    def __init__(self, lambda_p, max_steps=20):
+        super().__init__()
+
+        p_g = torch.zeros((max_steps,))
+        not_halted = 1.0
+
+        for k in range(max_steps):
+            p_g[k] = not_halted * lambda_p
+            not_halted = not_halted * (1 - lambda_p)
+
+        self.register_buffer("p_g", p_g)
+        self.kl_div = nn.KLDivLoss(reduction="batchmean")
+
+    def forward(self, p):
+        """Compute loss.
+
+        Parameters
+        ----------
+        p : torch.Tensor
+            Probability of halting of shape `(steps, batch_size)`.
+
+        Returns
+        -------
+        loss : torch.Tensor
+            Scalar representing the regularization loss.
+        """
+        steps, batch_size = p.shape
+
+        p = p.transpose(0, 1)  # (batch_size, max_steps)
+
+        p_g_batch = self.p_g[None, :steps].expand_as(
+            p
+        )  # (batch_size, max_steps)
+
+        return self.kl_div(p.log(), p_g_batch)
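To make the regularizer concrete: with the training default `lambda_p = 0.4`, the target prior puts probability 0.4 on halting at step 1, 0.4 * 0.6 = 0.24 on step 2, and so on, with an expected halting step of 1 / 0.4 = 2.5. A small sketch of the same construction outside the module (`max_steps` is shortened here just for display, and the printed values are rounded):

import torch

lambda_p, max_steps = 0.4, 5  # lambda_p matches the train.py default

p_g = torch.zeros(max_steps)
not_halted = 1.0
for k in range(max_steps):
    p_g[k] = not_halted * lambda_p   # P(halt exactly at step k + 1)
    not_halted *= 1 - lambda_p

print(p_g)  # roughly [0.400, 0.240, 0.144, 0.086, 0.052]
print((p_g * torch.arange(1, max_steps + 1)).sum())  # truncated mean; approaches 1 / lambda_p as max_steps grows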