import torch
import torch.nn as nn
from torch.utils.data import Dataset

# Import the names this module actually uses instead of a wildcard import.
from sparselinear import ActivationSparsity, SparseLinear


class ParityDataset(Dataset):
    """Parity of vectors - binary classification dataset.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.

    n_elems : int
        Size of the vectors.

    n_nonzero_min, n_nonzero_max : int or None
        Minimum (inclusive) and maximum (inclusive) number of nonzero
        elements in the feature vector. If not specified, the range
        `(1, n_elems)` is used.
    """

    def __init__(
        self,
        n_samples,
        n_elems,
        n_nonzero_min=None,
        n_nonzero_max=None,
    ):
        self.n_samples = n_samples
        self.n_elems = n_elems

        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
        self.n_nonzero_max = (
            n_elems if n_nonzero_max is None else n_nonzero_max
        )

        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems

    def __len__(self):
        """Get the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a feature vector and its parity (target).

        Note that the generating process is random: the same index can
        yield a different sample on each access.
        """
        # Place n_non_zero entries of +1/-1 at the start, then shuffle.
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(
            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
        ).item()
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]

        # Target: parity of the number of +1 entries.
        y = (x == 1.0).sum() % 2

        return x, torch.stack([y])


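# Minimal usage sketch (illustration only, not part of the original module):
# draw one sample and check that the target is the parity of the +1 entries.
#
#     ds = ParityDataset(n_samples=8, n_elems=16)
#     x, y = ds[0]
#     assert y.item() == (x == 1.0).sum().item() % 2

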
class PetriDishNet(nn.Module):
    """Network based on the PetriDish architecture.

    Is capable of performing an uncontrolled intelligence explosion.
    May or may not lead to the technological singularity when executed.
    Any similarity to any pop-culture AI, living or dead, or to utopian or
    dystopian events, is purely coincidental.

    Parameters
    ----------
    n_in : int
        Number of features in the input vector.

    n_out : int
        Number of features in the output vector.

    n_neurons : int
        Number of neurons.

    max_steps : int
        Maximum number of steps the network can "ponder" for.

    allow_halting : bool
        If True, then the forward pass is allowed to halt before
        reaching the maximum number of steps.

    Attributes
    ----------
    synapses : SparseLinear
        Sparse, dynamically rewired recurrent connections between all
        neurons.

    ionDucts : ActivationSparsity
        Regularizer encouraging sparse activations.
    """

    def __init__(
        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
    ):
        super().__init__()

        # The state must have room for the input, one halting neuron
        # and the output.
        assert (
            n_neurons >= n_in + n_out + 1
        ), f"Insufficient number of neurons (min: {n_in + n_out + 1})"

        self.n_in = n_in
        self.n_out = n_out
        self.n_neurons = n_neurons
        self.max_steps = max_steps
        self.allow_halting = allow_halting

        # Sparse recurrent weight matrix with small-world connectivity,
        # dynamically rewired during training.
        self.synapses = SparseLinear(
            n_neurons,
            n_neurons,
            sparsity=0.9,
            small_world=True,
            dynamic=True,
            deltaT=6000,
            Tend=150000,
            alpha=0.1,
            max_size=1e8,
        )

        # Activation-sparsity regularizer. Note that it is instantiated
        # here but not applied anywhere in forward().
        self.ionDucts = ActivationSparsity(
            alpha=0.1, beta=1.5, act_sparsity=0.65
        )

    def forward(self, x):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input features of shape `(batch_size, n_in)`.

        Returns
        -------
        y : torch.Tensor
            Tensor of shape `(max_steps, batch_size, n_out)` representing
            the predictions for each step and each sample. If
            `allow_halting=True`, the shape is `(steps, batch_size, n_out)`
            where `1 <= steps <= max_steps`.

        p : torch.Tensor
            Tensor of shape `(max_steps, batch_size)` representing the
            halting probabilities. Summing over the steps of a fixed sample
            gives 1. If `allow_halting=True`, the shape is
            `(steps, batch_size)` where `1 <= steps <= max_steps`, and the
            sums can fall slightly short of 1 because the loop stops once
            every sample's remaining probability mass is below 1%.

        halting_step : torch.Tensor
            An integer for each sample in the batch that corresponds to
            the step when it was halted. The shape is `(batch_size,)`. The
            minimal value is 1 because we always run at least one step.
        """
        batch_size, _ = x.shape
        device = x.device

        # Neuron layout along the feature dimension:
        # [0:n_in] input, [n_in] halting neuron,
        # [n_in + 1:n_in + n_out + 1] output.
        state = torch.nn.functional.pad(
            input=x,
            pad=(0, self.n_neurons - self.n_in),
            mode="constant",
            value=0,
        )

        un_halted_prob = x.new_ones(batch_size)

        y_list = []
        p_list = []

        halting_step = torch.zeros(
            batch_size,
            dtype=torch.long,
            device=device,
        )

        for n in range(1, self.max_steps + 1):
            state = self.synapses(state)
            if n == self.max_steps:
                lambda_n = x.new_ones(batch_size)  # (batch_size,)
            else:
                # Halting probability, read from the dedicated halting neuron.
                lambda_n = torch.sigmoid(state[:, self.n_in])

            # Store relevant outputs
            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
            p_list.append(un_halted_prob * lambda_n)

            # Record the first step at which each sample halts, sampled from
            # a Bernoulli distribution with probability lambda_n.
            halting_step = torch.maximum(
                n
                * (halting_step == 0)
                * torch.bernoulli(lambda_n).to(torch.long),
                halting_step,
            )

            # Prepare for next iteration
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Potentially stop early once almost all probability mass is
            # spent for every sample in the batch.
            if self.allow_halting and (un_halted_prob < 0.01).all():
                break

        y = torch.stack(y_list)
        p = torch.stack(p_list)

        return y, p, halting_step


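# Hedged usage sketch (assumes the sparselinear package is installed):
#
#     net = PetriDishNet(n_in=16, n_out=1, n_neurons=64, max_steps=20)
#     x = torch.randn(4, 16)
#     y, p, halting_step = net(x)
#     # y: (steps, 4, 1), p: (steps, 4), halting_step: (4,)

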
class ReconstructionLoss(nn.Module):
    """Weighted average of per-step losses.

    Parameters
    ----------
    loss_func : callable
        Loss function that accepts `y_pred` and `y_true` as arguments. Both
        of these tensors have shape `(batch_size, n_out)`. It outputs a loss
        for each sample in the batch, i.e. a tensor of shape `(batch_size,)`.
    """

    def __init__(self, loss_func):
        super().__init__()

        self.loss_func = loss_func

    def forward(self, p, y_pred, y_true):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(max_steps, batch_size)`.

        y_pred : torch.Tensor
            Predicted outputs of shape `(max_steps, batch_size, n_out)`.

        y_true : torch.Tensor
            True targets of shape `(batch_size, n_out)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the reconstruction loss. It is nothing else
            than a weighted sum of per-step losses.
        """
        max_steps, _ = p.shape
        total_loss = p.new_tensor(0.0)

        # Weight each step's mean loss by the probability of halting there.
        for n in range(max_steps):
            loss_per_sample = p[n] * self.loss_func(
                y_pred[n], y_true
            )  # (batch_size,)
            total_loss = total_loss + loss_per_sample.mean()  # (1,)

        return total_loss


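# Hedged usage sketch: a per-sample loss can be built by reducing only over
# the feature dimension, e.g. binary cross-entropy on logits:
#
#     def per_sample_bce(y_pred, y_true):
#         return nn.functional.binary_cross_entropy_with_logits(
#             y_pred, y_true, reduction="none"
#         ).mean(-1)  # (batch_size,)
#
#     rec_loss = ReconstructionLoss(per_sample_bce)

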
class RegularizationLoss(nn.Module):
    """Enforce the halting distribution to resemble a geometric distribution.

    Parameters
    ----------
    lambda_p : float
        The single parameter uniquely determining the geometric distribution.
        Note that the expected value of this distribution is `1 / lambda_p`.

    max_steps : int
        Maximum number of pondering steps.
    """

    def __init__(self, lambda_p, max_steps=20):
        super().__init__()

        # Truncated geometric prior: p_g[k] = (1 - lambda_p) ** k * lambda_p.
        p_g = torch.zeros((max_steps,))
        not_halted = 1.0

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        self.register_buffer("p_g", p_g)
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, p):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(steps, batch_size)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the regularization loss.
        """
        steps, _ = p.shape

        p = p.transpose(0, 1)  # (batch_size, steps)

        # Truncate the prior to the actual number of steps and broadcast it
        # over the batch.
        p_g_batch = self.p_g[None, :steps].expand_as(p)  # (batch_size, steps)

        return self.kl_div(p.log(), p_g_batch)
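

# Hedged sketch of how the pieces might be combined into a PonderNet-style
# objective. `per_sample_bce` is the hypothetical loss from the sketch above
# and `beta_reg` is a hypothetical weighting hyperparameter; neither is
# defined in this module.
#
#     rec = ReconstructionLoss(per_sample_bce)
#     reg = RegularizationLoss(lambda_p=0.2, max_steps=20)
#     y, p, halting_step = net(x)
#     loss = rec(p, y, y_true) + beta_reg * reg(p)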