import torch
import torch.nn as nn
from torch.utils.data import Dataset

from sparselinear import SparseLinear, ActivationSparsity


class ParityDataset(Dataset):
    """Parity of vectors - binary classification dataset.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.

    n_elems : int
        Size of the vectors.

    n_nonzero_min, n_nonzero_max : int or None
        Minimum (inclusive) and maximum (inclusive) number of nonzero
        elements in the feature vector. If not specified then `(1, n_elems)`.
    """

    def __init__(
        self,
        n_samples,
        n_elems,
        n_nonzero_min=None,
        n_nonzero_max=None,
    ):
        self.n_samples = n_samples
        self.n_elems = n_elems
        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
        self.n_nonzero_max = (
            n_elems if n_nonzero_max is None else n_nonzero_max
        )

        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems

    def __len__(self):
        """Get the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a feature vector and its parity (target).

        Note that the generating process is random.
        """
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(
            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
        ).item()
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]

        y = (x == 1.0).sum() % 2

        return x, torch.stack([y])


class PetriDishNet(nn.Module):
    """Network based on the PetriDish architecture.

    Is capable of performing an uncontrolled intelligence explosion.
    May or may not lead to the technological singularity when executed.
    Any similarity to any pop-culture AI, living or dead, or to utopic or
    dystopic events, is purely coincidental.

    Parameters
    ----------
    n_in : int
        Number of features in the input vector.

    n_out : int
        Number of features in the output vector.

    n_neurons : int
        Number of neurons.

    max_steps : int
        Maximum number of steps the network can "ponder" for.

    allow_halting : bool
        If True, then the forward pass is allowed to halt before
        reaching the maximum steps.

    Attributes
    ----------
    synapses : SparseLinear
        Dynamic, sparse linear layer that acts as the recurrent
        neuron-to-neuron connectivity of the network.

    ionDucts : ActivationSparsity
        Activation-sparsity regularizer (currently not applied in
        `forward`).
    """

    def __init__(
        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
    ):
        super().__init__()

        assert n_neurons >= n_in + n_out + 1, (
            f"Insufficient number of neurons (min: {n_in + n_out + 1})"
        )

        self.n_in = n_in
        self.n_out = n_out
        self.n_neurons = n_neurons
        self.max_steps = max_steps
        self.allow_halting = allow_halting

        self.synapses = SparseLinear(
            n_neurons,
            n_neurons,
            sparsity=0.8,
            # sparsity=0.1,
            small_world=True,
            dynamic=True,
            # deltaT=6000,
            deltaT=250,
            Tend=150000,
            # Tend=5000,
            alpha=0.1,
            max_size=1e8,
        )

        self.ionDucts = ActivationSparsity(
            alpha=0.1, beta=1.5, act_sparsity=0.75  # act_sparsity=0.10
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input features of shape `(batch_size, n_elems)`.

        Returns
        -------
        y : torch.Tensor
            Tensor of shape `(max_steps, batch_size, n_out)` representing
            the predictions for each step and each sample. In case
            `allow_halting=True` then the shape is
            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.

        p : torch.Tensor
            Tensor of shape `(max_steps, batch_size)` representing the
            halting probabilities. For a fixed sample, the probabilities
            sum to 1 over the steps, since the final step is forced to
            halt. In case `allow_halting=True` then the shape is
            `(steps, batch_size)` where `1 <= steps <= max_steps`.

        halting_step : torch.Tensor
            An integer for each sample in the batch that corresponds to
            the step when it was halted. The shape is `(batch_size,)`. The
            minimal value is 1 because we always run at least one step.
""" batch_size, _ = x.shape device = x.device state = torch.nn.functional.pad(input=x, pad=(0, self.n_neurons - self.n_in), mode='constant', value=0) un_halted_prob = x.new_ones(batch_size) y_list = [] p_list = [] halting_step = torch.zeros( batch_size, dtype=torch.long, device=device, ) for n in range(1, self.max_steps + 1): state = self.synapses(state) if n == self.max_steps: lambda_n = x.new_ones(batch_size) # (batch_size,) else: lambda_n = torch.sigmoid(state[:, 0]) # Store releavant outputs y_list.append(state[:, (self.n_in+1):(self.n_in+self.n_out+1)]) p_list.append(un_halted_prob * lambda_n) halting_step = torch.maximum( n * (halting_step == 0) * torch.bernoulli(lambda_n).to(torch.long), halting_step, ) # Prepare for next iteration un_halted_prob = un_halted_prob * (1 - lambda_n) # Potentially stop if all samples halted #if self.allow_halting and (halting_step > 0).sum() == batch_size: if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size: break y = torch.stack(y_list) p = torch.stack(p_list) return y, p, halting_step class ReconstructionLoss(nn.Module): """Weighted average of per step losses. Parameters ---------- loss_func : callable Loss function that accepts `y_pred` and `y_true` as arguments. Both of these tensors have shape `(batch_size, n_out)`. It outputs a loss for each sample in the batch. """ def __init__(self, loss_func): super().__init__() self.loss_func = loss_func def forward(self, p, y_pred, y_true): """Compute loss. Parameters ---------- p : torch.Tensor Probability of halting of shape `(max_steps, batch_size)`. y_pred : torch.Tensor Predicted outputs of shape `(max_steps, batch_size, n_out)`. y_true : torch.Tensor True targets of shape `(batch_size, n_out)`. Returns ------- loss : torch.Tensor Scalar representing the reconstruction loss. It is nothing else than a weighted sum of per step losses. """ max_steps, _ = p.shape total_loss = p.new_tensor(0.0) for n in range(max_steps): loss_per_sample = p[n] * self.loss_func( y_pred[n], y_true ) # (batch_size,) total_loss = total_loss + loss_per_sample.mean() # (1,) return total_loss class RegularizationLoss(nn.Module): """Enforce halting distribution to ressemble the geometric distribution. Parameters ---------- lambda_p : float The single parameter determining uniquely the geometric distribution. Note that the expected value of this distribution is going to be `1 / lambda_p`. max_steps : int Maximum number of pondering steps. """ def __init__(self, lambda_p, max_steps=20): super().__init__() p_g = torch.zeros((max_steps,)) not_halted = 1.0 for k in range(max_steps): p_g[k] = not_halted * lambda_p not_halted = not_halted * (1 - lambda_p) self.register_buffer("p_g", p_g) self.kl_div = nn.KLDivLoss(reduction="batchmean") def forward(self, p): """Compute loss. Parameters ---------- p : torch.Tensor Probability of halting of shape `(steps, batch_size)`. Returns ------- loss : torch.Tensor Scalar representing the regularization loss. """ steps, batch_size = p.shape p = p.transpose(0, 1) # (batch_size, max_steps) p_g_batch = self.p_g[None, :steps].expand_as( p ) # (batch_size, max_steps) return self.kl_div(p.log(), p_g_batch)