From 63cd063c7084b35bb120e0283765c6cbcd7a0c0a Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Thu, 14 Oct 2021 22:09:14 +0200
Subject: [PATCH] initial commit

---
 README.md                                     |  12 +
 __pycache__/utils.cpython-39.pyc              | Bin 0 -> 7910 bytes
 ....tfevents.1634241559.dodox-XPS-15.224268.0 | Bin 0 -> 346 bytes
 train.py                                      | 398 ++++++++++++++++++
 utils.py                                      | 289 +++++++++++++
 5 files changed, 699 insertions(+)
 create mode 100644 README.md
 create mode 100644 __pycache__/utils.cpython-39.pyc
 create mode 100644 results/experiment_b/31742/events.out.tfevents.1634241559.dodox-XPS-15.224268.0
 create mode 100644 train.py
 create mode 100644 utils.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4f14281
--- /dev/null
+++ b/README.md
@@ -0,0 +1,12 @@
+# Project PetriDish
+Quick and dirty PoC for the idea behind Project Neuromorph.
+Combines PonderNet and SparseLinear.
+PonderNet stolen from https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/pondernet
+SparseLinear stolen from https://pypi.org/project/sparselinear/
+
+## Architecture
+We build a neural network comprised of a set of neurons that are connected by a set of synapses. Neurons that are close to each other (we use a 1D distance metric; Project Neuromorph will use an approximate Euclidean distance) have a higher chance of having a synaptic connection.
+We train this net as usual, but we also allow the structure of the synaptic connections to change during training. (The number of neurons remains constant; in Neuromorph this will also be variable, and Neuromorph will use more advanced algorithms to decide where to spawn new synapses.)
+In every firing cycle only a fraction of all neurons (those with the highest output) are allowed to fire; all others are inhibited. (In Project Neuromorph this will be less strict: neurons with low firing rates will have higher dropout chances, and we discourage firing through an additional loss.)
+Based on the PonderNet architecture, we allow our network to 'think' about a given problem for as long as it wants (well, almost: there is a maximum number of firing cycles to make training possible).
+The inputs and outputs of the network have a fixed length. (Project Neuromorph will also allow variable-length outputs, as in an RNN.)
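+
+A rough, illustrative sketch of the distance-based wiring idea (not the actual implementation; in this PoC the sparse wiring comes from `SparseLinear(..., small_world=True, dynamic=True)`, and the decay length below is arbitrary):
+```python
+import torch
+
+n_neurons = 64
+idx = torch.arange(n_neurons, dtype=torch.float32)
+dist = (idx[:, None] - idx[None, :]).abs()  # 1D distance between neurons
+p_connect = torch.exp(-dist / 8.0)          # closer neurons -> higher connection probability
+adjacency = torch.bernoulli(p_connect)      # sampled synapse mask
+```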
diff --git a/__pycache__/utils.cpython-39.pyc b/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..203978fcd8633266caf95073a5a2cb8f4471595d
GIT binary patch
literal 7910
[base85-encoded contents of the compiled utils.cpython-39.pyc omitted]

[binary diff for results/experiment_b/31742/events.out.tfevents.1634241559.dodox-XPS-15.224268.0 (Bin 0 -> 346 bytes) not recoverable from this text]

diff --git a/train.py b/train.py
new file mode 100644
--- /dev/null
+++ b/train.py
@@ -0,0 +1,398 @@
[the imports and the opening of `evaluate()` are garbled in the source text; the recoverable part of train.py resumes inside the evaluation batch loop]
+        accuracy_halted = (
+            ((y_halted_batch > 0) == y_true_batch).to(torch.float32).mean()
+        )
+
+        metrics_single_["accuracy_halted"].append(accuracy_halted)
+        metrics_single_["halting_step"].append(
+            halting_step.to(torch.float).mean()
+        )
+
+        # Computing per step metrics (mean over samples in the batch)
+        accuracy = (
+            ((y_pred_batch > 0) == y_true_batch[None, :])
+            .to(torch.float32)
+            .mean(dim=1)
+        )
+
+        metrics_per_step_["accuracy"].append(accuracy)
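+        # `p` has shape (max_steps, batch_size); averaging over dim=1 gives the
+        # mean halting distribution of this batch, one value per pondering step.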
+        metrics_per_step_["p"].append(p.mean(dim=1))
+
+    metrics_single = {
+        name: torch.stack(values).mean(dim=0).cpu().numpy()
+        for name, values in metrics_single_.items()
+    }
+
+    metrics_per_step = {
+        name: torch.stack(values).mean(dim=0).cpu().numpy()
+        for name, values in metrics_per_step_.items()
+    }
+
+    return metrics_single, metrics_per_step
+
+
+def plot_distributions(target, predicted):
+    """Create a barplot comparing the target and predicted halting distributions.
+
+    Parameters
+    ----------
+    target, predicted : np.ndarray
+        Arrays of shape `(max_steps,)` representing the target and predicted
+        probability distributions.
+
+    Returns
+    -------
+    matplotlib.Figure
+    """
+    support = list(range(1, len(target) + 1))
+
+    fig, ax = plt.subplots(dpi=140)
+
+    ax.bar(
+        support,
+        target,
+        color="red",
+        label=f"Target - Geometric({target[0].item():.2f})",
+    )
+
+    ax.bar(
+        support,
+        predicted,
+        color="green",
+        width=0.4,
+        label="Predicted",
+    )
+
+    ax.set_ylim(0, 0.6)
+    ax.set_xticks(support)
+    ax.legend()
+    ax.grid()
+
+    return fig
+
+
+def plot_accuracy(accuracy):
+    """Create a barplot representing accuracy over different halting steps.
+
+    Parameters
+    ----------
+    accuracy : np.ndarray
+        1D array representing the accuracy if we were to take the output after
+        the corresponding step.
+
+    Returns
+    -------
+    matplotlib.Figure
+    """
+    support = list(range(1, len(accuracy) + 1))
+
+    fig, ax = plt.subplots(dpi=140)
+
+    ax.bar(
+        support,
+        accuracy,
+        label="Accuracy over different steps",
+    )
+
+    ax.set_ylim(0, 1)
+    ax.set_xticks(support)
+    ax.legend()
+    ax.grid()
+
+    return fig
+
+
+def main(argv=None):
+    """CLI for training."""
+    parser = ArgumentParser()
+
+    parser.add_argument(
+        "log_folder",
+        type=str,
+        help="Folder where tensorboard logging is saved",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=128,
+        help="Batch size",
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        default=0.01,
+        help="Regularization loss coefficient",
+    )
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        choices={"cpu", "cuda"},
+        default="cpu",
+        help="Device to use",
+    )
+    parser.add_argument(
+        "--eval-frequency",
+        type=int,
+        default=10_000,
+        help="Evaluation is run every `eval_frequency` steps",
+    )
+    parser.add_argument(
+        "--lambda-p",
+        type=float,
+        default=0.4,
+        help="True probability of success for a geometric distribution",
+    )
+    parser.add_argument(
+        "--n-iter",
+        type=int,
+        default=1_000_000,
+        help="Number of gradient steps",
+    )
+    parser.add_argument(
+        "--n-elems",
+        type=int,
+        default=64,
+        help="Number of elements in each input vector",
+    )
+    parser.add_argument(
+        "--n-hidden",
+        type=int,
+        default=64,
+        help="Number of neurons in the network (must be >= n_elems + 2)",
+    )
+    parser.add_argument(
+        "--n-nonzero",
+        type=int,
+        nargs=2,
+        default=(None, None),
+        help="Lower and upper bound on nonzero elements in the training set",
+    )
+    parser.add_argument(
+        "--max-steps",
+        type=int,
+        default=20,
+        help="Maximum number of pondering steps",
+    )
+
+    # Parameters
+    args = parser.parse_args(argv)
+    print(args)
+
+    device = torch.device(args.device)
+    dtype = torch.float32
+    n_eval_samples = 1000
+    batch_size_eval = 50
+
+    if args.n_nonzero[0] is None and args.n_nonzero[1] is None:
+        threshold = int(0.3 * args.n_elems)
+        range_nonzero_easy = (1, threshold)
+        range_nonzero_hard = (args.n_elems - threshold, args.n_elems)
+    else:
+        range_nonzero_easy = (1, args.n_nonzero[1])
+        range_nonzero_hard = (args.n_nonzero[1] + 1, args.n_elems)
+
+    # Tensorboard
+    log_folder = pathlib.Path(args.log_folder)
+    writer = SummaryWriter(log_folder)
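+    # Log the full CLI configuration as text so each run's hyperparameters are
+    # stored alongside its TensorBoard event file.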
+    writer.add_text("parameters", json.dumps(vars(args)))
+
+    # Prepare data
+    dataloader_train = DataLoader(
+        ParityDataset(
+            n_samples=args.batch_size * args.n_iter,
+            n_elems=args.n_elems,
+            n_nonzero_min=args.n_nonzero[0],
+            n_nonzero_max=args.n_nonzero[1],
+        ),
+        batch_size=args.batch_size,
+    )  # consider specifying `num_workers` for speedups
+    eval_dataloaders = {
+        "test": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=args.n_nonzero[0],
+                n_nonzero_max=args.n_nonzero[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+        f"{range_nonzero_easy[0]}_{range_nonzero_easy[1]}": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=range_nonzero_easy[0],
+                n_nonzero_max=range_nonzero_easy[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+        f"{range_nonzero_hard[0]}_{range_nonzero_hard[1]}": DataLoader(
+            ParityDataset(
+                n_samples=n_eval_samples,
+                n_elems=args.n_elems,
+                n_nonzero_min=range_nonzero_hard[0],
+                n_nonzero_max=range_nonzero_hard[1],
+            ),
+            batch_size=batch_size_eval,
+        ),
+    }
+
+    # Model preparation
+    module = PetriDishNet(
+        n_in=args.n_elems,
+        n_neurons=args.n_hidden,
+        n_out=1,
+        max_steps=args.max_steps,
+    )
+    module = module.to(device, dtype)
+
+    # Loss preparation
+    loss_rec_inst = ReconstructionLoss(
+        nn.BCEWithLogitsLoss(reduction="none")
+    ).to(device, dtype)
+
+    loss_reg_inst = RegularizationLoss(
+        lambda_p=args.lambda_p,
+        max_steps=args.max_steps,
+    ).to(device, dtype)
+
+    # Optimizer
+    optimizer = torch.optim.Adam(
+        module.parameters(),
+        lr=0.0003,
+    )
+
+    # Training and evaluation loops
+    iterator = tqdm(enumerate(dataloader_train), total=args.n_iter)
+    for step, (x_batch, y_true_batch) in iterator:
+        x_batch = x_batch.to(device, dtype)
+        y_true_batch = y_true_batch.to(device, dtype)
+
+        y_pred_batch, p, halting_step = module(x_batch)
+
+        loss_rec = loss_rec_inst(
+            p,
+            y_pred_batch,
+            y_true_batch,
+        )
+
+        loss_reg = loss_reg_inst(
+            p,
+        )
+
+        loss_overall = loss_rec + args.beta * loss_reg
+
+        optimizer.zero_grad()
+        loss_overall.backward()
+        torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
+        optimizer.step()
+
+        # Logging
+        writer.add_scalar("loss_rec", loss_rec, step)
+        writer.add_scalar("loss_reg", loss_reg, step)
+        writer.add_scalar("loss_overall", loss_overall, step)
+
+        # Evaluation
+        if step % args.eval_frequency == 0:
+            module.eval()
+
+            for dataloader_name, dataloader in eval_dataloaders.items():
+                metrics_single, metrics_per_step = evaluate(
+                    dataloader,
+                    module,
+                )
+                fig_dist = plot_distributions(
+                    loss_reg_inst.p_g.cpu().numpy(),
+                    metrics_per_step["p"],
+                )
+                writer.add_figure(
+                    f"distributions/{dataloader_name}", fig_dist, step
+                )
+
+                fig_acc = plot_accuracy(metrics_per_step["accuracy"])
+                writer.add_figure(
+                    f"accuracy_per_step/{dataloader_name}", fig_acc, step
+                )
+
+                for metric_name, metric_value in metrics_single.items():
+                    writer.add_scalar(
+                        f"{metric_name}/{dataloader_name}",
+                        metric_value,
+                        step,
+                    )
+
+            torch.save(module, log_folder / "checkpoint.pth")
+
+            module.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..e4fdf41
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,289 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+from sparselinear import ActivationSparsity, SparseLinear
+
+class ParityDataset(Dataset):
+    """Parity of vectors - binary classification dataset.
+
+    Parameters
+    ----------
+    n_samples : int
+        Number of samples to generate.
+
+    n_elems : int
+        Size of the vectors.
+
+    n_nonzero_min, n_nonzero_max : int or None
+        Minimum (inclusive) and maximum (inclusive) number of nonzero
+        elements in the feature vector. If not specified then `(1, n_elems)`.
+    """
+
+    def __init__(
+        self,
+        n_samples,
+        n_elems,
+        n_nonzero_min=None,
+        n_nonzero_max=None,
+    ):
+        self.n_samples = n_samples
+        self.n_elems = n_elems
+
+        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
+        self.n_nonzero_max = (
+            n_elems if n_nonzero_max is None else n_nonzero_max
+        )
+
+        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems
+
+    def __len__(self):
+        """Get the number of samples."""
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        """Get a feature vector and its parity (target).
+
+        Note that the generating process is random.
+        """
+        x = torch.zeros((self.n_elems,))
+        n_non_zero = torch.randint(
+            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
+        ).item()
+        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
+        x = x[torch.randperm(self.n_elems)]
+
+        y = (x == 1.0).sum() % 2
+
+        return x, y
+
+
+class PetriDishNet(nn.Module):
+    """Network based on the PetriDish architecture.
+
+    Is capable of performing an uncontrolled intelligence explosion.
+    May or may not lead to the technological singularity when executed.
+    Any similarity to any pop-culture AI, living or dead, or utopic or
+    dystopic events, is purely coincidental.
+
+    Parameters
+    ----------
+    n_in : int
+        Number of features in the input vector.
+
+    n_out : int
+        Number of features in the output vector.
+
+    n_neurons : int
+        Number of neurons.
+
+    max_steps : int
+        Maximum number of steps the network can "ponder" for.
+
+    allow_halting : bool
+        If True, then the forward pass is allowed to halt before
+        reaching the maximum steps.
+
+    Attributes
+    ----------
+    synapses : SparseLinear
+        Sparse linear layer wiring the neurons to each other; configured with
+        dynamic, small-world sparsity so the connection structure can change
+        during training.
+
+    ionDucts : ActivationSparsity
+        Activation-sparsity regularizer for the neuron outputs (the firing
+        inhibition described in the README); currently constructed but not
+        applied in `forward`.
+    """
+
+    def __init__(
+        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
+    ):
+        super().__init__()
+
+        assert n_neurons >= n_in + n_out + 1, (
+            f"Insufficient number of neurons (min: {n_in + n_out + 1})"
+        )
+
+        self.n_in = n_in
+        self.n_out = n_out
+        self.n_neurons = n_neurons
+        self.max_steps = max_steps
+        self.allow_halting = allow_halting
+
+        self.synapses = SparseLinear(
+            n_neurons,
+            n_neurons,
+            sparsity=0.9,
+            small_world=True,
+            dynamic=True,
+            deltaT=6000,
+            Tend=150000,
+            alpha=0.1,
+            max_size=1e8,
+        )
+
+        self.ionDucts = ActivationSparsity(
+            n_neurons,
+            alpha=0.1,
+            beta=1.5,
+            act_sparsity=0.65,
+        )
+
+    def forward(self, x):
+        """Run forward pass.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Batch of input features of shape `(batch_size, n_elems)`.
+
+        Returns
+        -------
+        y : torch.Tensor
+            Tensor of shape `(max_steps, batch_size, n_out)` representing
+            the predictions for each step and each sample. In case
+            `allow_halting=True` then the shape is
+            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.
+
+        p : torch.Tensor
+            Tensor of shape `(max_steps, batch_size)` representing
+            the halting probabilities. Sums over rows (fixing a sample)
+            are 1. In case `allow_halting=True` then the shape is
+            `(steps, batch_size)` where `1 <= steps <= max_steps`.
+
+        halting_step : torch.Tensor
+            An integer for each sample in the batch that corresponds to
+            the step when it was halted. The shape is `(batch_size,)`. The
+            minimal value is 1 because we always run at least one step.
+        """
+        batch_size, _ = x.shape
+        device = x.device
+
+        state = torch.nn.functional.pad(
+            input=x, pad=(0, self.n_neurons - self.n_in), mode="constant", value=0
+        )
+
+        un_halted_prob = x.new_ones(batch_size)
+
+        y_list = []
+        p_list = []
+
+        halting_step = torch.zeros(
+            batch_size,
+            dtype=torch.long,
+            device=device,
+        )
+
+        for n in range(1, self.max_steps + 1):
+            state = self.synapses(state)
+            # The halting probability is read from neuron 0, the prediction
+            # from neurons n_in+1 .. n_in+n_out (see y_list below).
+            if n == self.max_steps:
+                lambda_n = x.new_ones(batch_size)  # (batch_size,)
+            else:
+                lambda_n = torch.sigmoid(state[:, 0])
+
+            # Store relevant outputs
+            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
+            p_list.append(un_halted_prob * lambda_n)
+
+            halting_step = torch.maximum(
+                n
+                * (halting_step == 0)
+                * torch.bernoulli(lambda_n).to(torch.long),
+                halting_step,
+            )
+
+            # Prepare for next iteration
+            un_halted_prob = un_halted_prob * (1 - lambda_n)
+
+            # Potentially stop if all samples halted
+            # if self.allow_halting and (halting_step > 0).sum() == batch_size:
+            if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
+                break
+
+        y = torch.stack(y_list)
+        p = torch.stack(p_list)
+
+        return y, p, halting_step
+
+
+class ReconstructionLoss(nn.Module):
+    """Weighted average of per-step losses.
+
+    Parameters
+    ----------
+    loss_func : callable
+        Loss function that accepts `y_pred` and `y_true` as arguments. Both
+        of these tensors have shape `(batch_size, n_out)`. It outputs a loss
+        for each sample in the batch.
+    """
+
+    def __init__(self, loss_func):
+        super().__init__()
+
+        self.loss_func = loss_func
+
+    def forward(self, p, y_pred, y_true):
+        """Compute loss.
+
+        Parameters
+        ----------
+        p : torch.Tensor
+            Probability of halting of shape `(max_steps, batch_size)`.
+
+        y_pred : torch.Tensor
+            Predicted outputs of shape `(max_steps, batch_size, n_out)`.
+
+        y_true : torch.Tensor
+            True targets of shape `(batch_size, n_out)`.
+
+        Returns
+        -------
+        loss : torch.Tensor
+            Scalar representing the reconstruction loss. It is simply a
+            weighted sum of the per-step losses.
+        """
+        max_steps, _ = p.shape
+        total_loss = p.new_tensor(0.0)
+
+        for n in range(max_steps):
+            loss_per_sample = p[n] * self.loss_func(
+                y_pred[n], y_true
+            )  # (batch_size,)
+            total_loss = total_loss + loss_per_sample.mean()  # (1,)
+
+        return total_loss
+
+
+class RegularizationLoss(nn.Module):
+    """Enforce the halting distribution to resemble a geometric distribution.
+
+    Parameters
+    ----------
+    lambda_p : float
+        The single parameter that uniquely determines the geometric
+        distribution. Note that the expected value of this distribution is
+        going to be `1 / lambda_p`.
+
+    max_steps : int
+        Maximum number of pondering steps.
+    """
+
+    def __init__(self, lambda_p, max_steps=20):
+        super().__init__()
+
+        p_g = torch.zeros((max_steps,))
+        not_halted = 1.0
+
+        for k in range(max_steps):
+            p_g[k] = not_halted * lambda_p
+            not_halted = not_halted * (1 - lambda_p)
+
+        self.register_buffer("p_g", p_g)
+        self.kl_div = nn.KLDivLoss(reduction="batchmean")
+
+    def forward(self, p):
+        """Compute loss.
+
+        Parameters
+        ----------
+        p : torch.Tensor
+            Probability of halting of shape `(steps, batch_size)`.
+
+        Returns
+        -------
+        loss : torch.Tensor
+            Scalar representing the regularization loss.
+        """
+        steps, batch_size = p.shape
+
+        p = p.transpose(0, 1)  # (batch_size, max_steps)
+
+        p_g_batch = self.p_g[None, :steps].expand_as(
+            p
+        )  # (batch_size, max_steps)
+
+        return self.kl_div(p.log(), p_g_batch)