# petriDish/utils.py

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from sparselinear import SparseLinear, ActivationSparsity


class ParityDataset(Dataset):
    """Parity of vectors - binary classification dataset.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    n_elems : int
        Size of the vectors.
    n_nonzero_min, n_nonzero_max : int or None
        Minimum (inclusive) and maximum (inclusive) number of nonzero
        elements in the feature vector. If not specified then `(1, n_elems)`.
    """

    def __init__(
        self,
        n_samples,
        n_elems,
        n_nonzero_min=None,
        n_nonzero_max=None,
    ):
        self.n_samples = n_samples
        self.n_elems = n_elems
        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
        self.n_nonzero_max = (
            n_elems if n_nonzero_max is None else n_nonzero_max
        )

        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems

    def __len__(self):
        """Get the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a feature vector and its parity (target).

        Note that the generating process is random.
        """
        # Start from an all-zero vector, fill a random number of slots with
        # +1 / -1 entries, then shuffle their positions.
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(
            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
        ).item()
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]

        # The target is the parity of the number of +1 entries.
        y = (x == 1.0).sum() % 2

        return x, torch.stack([y])
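

# Minimal usage sketch (not part of the original module): it shows how
# ParityDataset is typically consumed through a DataLoader. The sample count
# and batch size below are illustrative assumptions, not values from the
# training setup.
def _demo_parity_dataset():
    from torch.utils.data import DataLoader

    dataset = ParityDataset(n_samples=128, n_elems=16)
    loader = DataLoader(dataset, batch_size=32)

    x, y = next(iter(loader))
    # x: (32, 16) feature vectors with entries in {-1, 0, +1}
    # y: (32, 1) parity of the number of +1 entries (0 or 1)
    return x.shape, y.shape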


class PetriDishNet(nn.Module):
    """Network based on the PetriDish architecture.

    It is capable of performing an uncontrolled intelligence explosion and
    may or may not lead to the technological singularity when executed.
    Any similarity to any pop-culture AI, living or dead, or to utopic or
    dystopic events, is purely coincidental.

    Parameters
    ----------
    n_in : int
        Number of features in the input vector.
    n_out : int
        Number of features in the output vector.
    n_neurons : int
        Number of neurons.
    max_steps : int
        Maximum number of steps the network can "ponder" for.
    allow_halting : bool
        If True, then the forward pass is allowed to halt before
        reaching the maximum number of steps.

    Attributes
    ----------
    synapses : SparseLinear
        Dynamic, small-world sparse linear layer of shape
        `(n_neurons, n_neurons)` connecting the neuron population to itself.
    ionDucts : ActivationSparsity
        Activation-sparsity module from the `sparselinear` package.
    """

    def __init__(
        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
    ):
        super().__init__()

        assert n_neurons >= n_in + n_out + 1, (
            f"Insufficient number of neurons (min: {n_in + n_out + 1})"
        )

        self.n_in = n_in
        self.n_out = n_out
        self.n_neurons = n_neurons
        self.max_steps = max_steps
        self.allow_halting = allow_halting

        self.synapses = SparseLinear(
            n_neurons,
            n_neurons,
            sparsity=0.8,
            # sparsity=0.1,
            small_world=True,
            dynamic=True,
            # deltaT=6000,
            deltaT=250,
            Tend=150000,
            # Tend=5000,
            alpha=0.1,
            max_size=1e8,
        )

        self.ionDucts = ActivationSparsity(
            alpha=0.1,
            beta=1.5,
            act_sparsity=0.75,
            # act_sparsity=0.10,
        )

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input features of shape `(batch_size, n_elems)`.

        Returns
        -------
        y : torch.Tensor
            Tensor of shape `(max_steps, batch_size, n_out)` representing
            the predictions for each step and each sample. In case
            `allow_halting=True` then the shape is
            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.
        p : torch.Tensor
            Tensor of shape `(max_steps, batch_size)` representing
            the halting probabilities. Sums over rows (fixing a sample)
            are 1. In case `allow_halting=True` then the shape is
            `(steps, batch_size)` where `1 <= steps <= max_steps`.
        halting_step : torch.Tensor
            An integer for each sample in the batch that corresponds to
            the step when it was halted. The shape is `(batch_size,)`. The
            minimal value is 1 because we always run at least one step.
        """
        batch_size, _ = x.shape
        device = x.device

        # Write the input into the first `n_in` neurons; the remaining
        # neurons start out at zero.
        state = torch.nn.functional.pad(
            input=x, pad=(0, self.n_neurons - self.n_in), mode="constant", value=0
        )

        un_halted_prob = x.new_ones(batch_size)

        y_list = []
        p_list = []

        halting_step = torch.zeros(
            batch_size,
            dtype=torch.long,
            device=device,
        )

        for n in range(1, self.max_steps + 1):
            # One synaptic update of the whole neuron population.
            state = self.synapses(state)

            if n == self.max_steps:
                lambda_n = x.new_ones(batch_size)  # (batch_size,)
            else:
                # Neuron 0 drives the per-step halting probability.
                lambda_n = torch.sigmoid(state[:, 0])

            # Store relevant outputs: read the `n_out` neurons that follow
            # the input block and the reserved neuron at index `n_in`.
            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
            p_list.append(un_halted_prob * lambda_n)

            halting_step = torch.maximum(
                n
                * (halting_step == 0)
                * torch.bernoulli(lambda_n).to(torch.long),
                halting_step,
            )

            # Prepare for next iteration
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Potentially stop if all samples halted
            # if self.allow_halting and (halting_step > 0).sum() == batch_size:
            if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
                break

        y = torch.stack(y_list)
        p = torch.stack(p_list)

        return y, p, halting_step
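

# Minimal forward-pass sketch (not part of the original module): it shows the
# shapes PetriDishNet returns. The sizes below are illustrative assumptions;
# they only have to satisfy n_neurons >= n_in + n_out + 1.
def _demo_petri_dish_net():
    net = PetriDishNet(n_in=16, n_out=1, n_neurons=64, max_steps=8)
    x = torch.randn(4, 16)

    y, p, halting_step = net(x)
    # y: (steps, 4, 1) per-step predictions
    # p: (steps, 4) halting probabilities (each column sums to 1)
    # halting_step: (4,) step (>= 1) at which each sample halted
    return y.shape, p.shape, halting_step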


class ReconstructionLoss(nn.Module):
    """Weighted average of per-step losses.

    Parameters
    ----------
    loss_func : callable
        Loss function that accepts `y_pred` and `y_true` as arguments. Both
        of these tensors have shape `(batch_size, n_out)`. It outputs a loss
        for each sample in the batch.
    """

    def __init__(self, loss_func):
        super().__init__()

        self.loss_func = loss_func

    def forward(self, p, y_pred, y_true):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(max_steps, batch_size)`.
        y_pred : torch.Tensor
            Predicted outputs of shape `(max_steps, batch_size, n_out)`.
        y_true : torch.Tensor
            True targets of shape `(batch_size, n_out)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the reconstruction loss. It is simply a
            weighted sum of the per-step losses, with the halting
            probabilities as weights.
        """
        max_steps, _ = p.shape
        total_loss = p.new_tensor(0.0)

        for n in range(max_steps):
            loss_per_sample = p[n] * self.loss_func(
                y_pred[n], y_true
            )  # (batch_size,)
            total_loss = total_loss + loss_per_sample.mean()  # (1,)

        return total_loss
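

# Minimal usage sketch (not part of the original module): ReconstructionLoss
# wraps any per-sample loss and weights it with the halting probabilities.
# The binary cross-entropy choice and the shapes below are illustrative
# assumptions.
def _demo_reconstruction_loss():
    # The wrapped callable must return one loss value per sample, so the
    # element-wise BCE is averaged over the output dimension here.
    loss_fn = ReconstructionLoss(
        lambda y_p, y_t: nn.functional.binary_cross_entropy_with_logits(
            y_p, y_t, reduction="none"
        ).mean(dim=-1)
    )

    max_steps, batch_size, n_out = 8, 4, 1
    p = torch.full((max_steps, batch_size), 1.0 / max_steps)
    y_pred = torch.randn(max_steps, batch_size, n_out)
    y_true = torch.randint(0, 2, (batch_size, n_out)).float()

    return loss_fn(p, y_pred, y_true)  # scalar tensor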


class RegularizationLoss(nn.Module):
    """Enforce the halting distribution to resemble a geometric distribution.

    Parameters
    ----------
    lambda_p : float
        The single parameter that uniquely determines the geometric
        distribution. Note that the expected value of this distribution is
        `1 / lambda_p`.
    max_steps : int
        Maximum number of pondering steps.
    """

    def __init__(self, lambda_p, max_steps=20):
        super().__init__()

        # Precompute the geometric prior p_g(k) = (1 - lambda_p)^k * lambda_p
        # over the first `max_steps` steps.
        p_g = torch.zeros((max_steps,))
        not_halted = 1.0

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        self.register_buffer("p_g", p_g)
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, p):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(steps, batch_size)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the regularization loss.
        """
        steps, batch_size = p.shape

        p = p.transpose(0, 1)  # (batch_size, max_steps)

        p_g_batch = self.p_g[None, :steps].expand_as(
            p
        )  # (batch_size, max_steps)

        return self.kl_div(p.log(), p_g_batch)
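

# Minimal usage sketch (not part of the original module): it shows how the
# regularizer is evaluated and how the two losses are usually combined.
# lambda_p, max_steps, and the weighting factor beta are illustrative
# assumptions, not tuned values.
def _demo_regularization_loss():
    max_steps, batch_size = 8, 4
    reg_loss_fn = RegularizationLoss(lambda_p=0.2, max_steps=max_steps)

    # A uniform halting distribution, just for demonstration; during training
    # `p` comes from PetriDishNet's forward pass.
    p = torch.full((max_steps, batch_size), 1.0 / max_steps)
    reg_loss = reg_loss_fn(p)

    # A typical total objective adds the regularizer to the reconstruction
    # loss with a small weight: total = rec_loss + beta * reg_loss.
    return reg_loss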