import torch
import torch.nn as nn
from torch.utils.data import Dataset
from sparselinear import SparseLinear, ActivationSparsity


class ParityDataset(Dataset):
    """Parity of vectors - binary classification dataset.

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.

    n_elems : int
        Size of the vectors.

    n_nonzero_min, n_nonzero_max : int or None
        Minimum (inclusive) and maximum (inclusive) number of nonzero
        elements in the feature vector. If not specified, `(1, n_elems)`
        is used.
    """

    def __init__(
        self,
        n_samples,
        n_elems,
        n_nonzero_min=None,
        n_nonzero_max=None,
    ):
        self.n_samples = n_samples
        self.n_elems = n_elems

        self.n_nonzero_min = 1 if n_nonzero_min is None else n_nonzero_min
        self.n_nonzero_max = (
            n_elems if n_nonzero_max is None else n_nonzero_max
        )

        assert 0 <= self.n_nonzero_min <= self.n_nonzero_max <= n_elems

    def __len__(self):
        """Get the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Get a feature vector and its parity (target).

        Note that the generating process is random.
        """
        x = torch.zeros((self.n_elems,))
        n_non_zero = torch.randint(
            self.n_nonzero_min, self.n_nonzero_max + 1, (1,)
        ).item()
        # Fill the first n_non_zero entries with random values in {-1, 1},
        # then shuffle so the nonzero positions are random.
        x[:n_non_zero] = torch.randint(0, 2, (n_non_zero,)) * 2 - 1
        x = x[torch.randperm(self.n_elems)]

        # Target: parity of the number of +1 entries.
        y = (x == 1.0).sum() % 2

        return x, torch.stack([y])
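

# A minimal usage sketch for ParityDataset; the sizes below are arbitrary
# and only illustrate the shapes and the parity relation.
def _demo_parity_dataset():
    dataset = ParityDataset(n_samples=4, n_elems=8)
    for i in range(len(dataset)):
        x, y = dataset[i]
        # x: shape (8,) with entries in {-1, 0, 1}; y: shape (1,)
        assert y.item() == (x == 1.0).sum().item() % 2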


class PetriDishNet(nn.Module):
    """Network based on the PetriDish architecture.

    It is capable of performing an uncontrolled intelligence explosion
    and may or may not lead to the technological singularity when executed.
    Any similarity to any pop-culture AI, living or dead, or to utopian or
    dystopian events, is purely coincidental.

    Parameters
    ----------
    n_in : int
        Number of features in the input vector.

    n_out : int
        Number of features in the output vector.

    n_neurons : int
        Total number of neurons; must be at least `n_in + n_out + 1`.

    max_steps : int
        Maximum number of steps the network can "ponder" for.

    allow_halting : bool
        If True, then the forward pass is allowed to halt before
        reaching the maximum steps.

    Attributes
    ----------
    synapses : SparseLinear
        Sparse, dynamically rewired linear layer that maps the full neuron
        state onto itself at every pondering step.
    """

    def __init__(
        self, n_in, n_out, n_neurons=64, max_steps=1024, allow_halting=False
    ):
        super().__init__()

        assert (
            n_neurons >= n_in + n_out + 1
        ), f"Insufficient number of neurons (min: {n_in + n_out + 1})"

        self.n_in = n_in
        self.n_out = n_out
        self.n_neurons = n_neurons
        self.max_steps = max_steps
        self.allow_halting = allow_halting

        # Recurrent, sparsely connected weight matrix over the full neuron
        # state. With `dynamic=True` the connectivity pattern itself is
        # learned (prune-and-regrow updates every `deltaT` steps until
        # `Tend`).
        self.synapses = SparseLinear(n_neurons, n_neurons,
                                     sparsity=0.8,
                                     # sparsity=0.1,
                                     small_world=True,
                                     dynamic=True,
                                     # deltaT=6000,
                                     deltaT=250,
                                     Tend=150000,
                                     # Tend=5000,
                                     alpha=0.1,
                                     max_size=1e8)

        # Note: instantiated but currently not applied in `forward`.
        self.ionDucts = ActivationSparsity(alpha=0.1,
                                           beta=1.5,
                                           act_sparsity=0.75)
                                           # act_sparsity=0.10)

    def forward(self, x):
        """Run forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of input features of shape `(batch_size, n_elems)`.

        Returns
        -------
        y : torch.Tensor
            Tensor of shape `(max_steps, batch_size, n_out)` representing
            the predictions for each step and each sample. In case
            `allow_halting=True` then the shape is
            `(steps, batch_size, n_out)` where `1 <= steps <= max_steps`.

        p : torch.Tensor
            Tensor of shape `(max_steps, batch_size)` representing
            the halting probabilities. Sums over rows (fixing a sample)
            are 1. In case `allow_halting=True` then the shape is
            `(steps, batch_size)` where `1 <= steps <= max_steps`.

        halting_step : torch.Tensor
            An integer for each sample in the batch that corresponds to
            the step when it was halted. The shape is `(batch_size,)`. The
            minimal value is 1 because we always run at least one step.
        """
        batch_size, _ = x.shape
        device = x.device

        # Neuron state: the first n_in entries carry the input features,
        # the remaining entries start at zero.
        state = torch.nn.functional.pad(
            input=x, pad=(0, self.n_neurons - self.n_in), mode="constant", value=0
        )

        un_halted_prob = x.new_ones(batch_size)

        y_list = []
        p_list = []

        halting_step = torch.zeros(
            batch_size,
            dtype=torch.long,
            device=device,
        )

        for n in range(1, self.max_steps + 1):
            state = self.synapses(state)

            # Halting probability for this step is read off neuron 0; on
            # the final step we force halting.
            if n == self.max_steps:
                lambda_n = x.new_ones(batch_size)  # (batch_size,)
            else:
                lambda_n = torch.sigmoid(state[:, 0])

            # Store relevant outputs: predictions are read from neurons
            # n_in + 1 .. n_in + n_out.
            y_list.append(state[:, (self.n_in + 1):(self.n_in + self.n_out + 1)])
            p_list.append(un_halted_prob * lambda_n)

            halting_step = torch.maximum(
                n
                * (halting_step == 0)
                * torch.bernoulli(lambda_n).to(torch.long),
                halting_step,
            )

            # Prepare for next iteration
            un_halted_prob = un_halted_prob * (1 - lambda_n)

            # Potentially stop if all samples halted
            # if self.allow_halting and (halting_step > 0).sum() == batch_size:
            if self.allow_halting and (un_halted_prob < 0.01).sum() == batch_size:
                break

        y = torch.stack(y_list)
        p = torch.stack(p_list)

        return y, p, halting_step
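

# A minimal usage sketch for PetriDishNet on parity inputs; the sizes here
# are arbitrary, and the sketch assumes the SparseLinear layer constructs
# successfully with the settings chosen in __init__.
def _demo_petri_dish_net():
    net = PetriDishNet(n_in=8, n_out=1, n_neurons=64, max_steps=4)
    dataset = ParityDataset(n_samples=2, n_elems=8)
    x = torch.stack([dataset[i][0] for i in range(2)])
    y, p, halting_step = net(x)
    # y: (steps, 2, 1), p: (steps, 2), halting_step: (2,)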


class ReconstructionLoss(nn.Module):
    """Weighted average of per step losses.

    Parameters
    ----------
    loss_func : callable
        Loss function that accepts `y_pred` and `y_true` as arguments. Both
        of these tensors have shape `(batch_size, n_out)`. It outputs a loss
        for each sample in the batch.
    """

    def __init__(self, loss_func):
        super().__init__()

        self.loss_func = loss_func

    def forward(self, p, y_pred, y_true):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(max_steps, batch_size)`.

        y_pred : torch.Tensor
            Predicted outputs of shape `(max_steps, batch_size, n_out)`.

        y_true : torch.Tensor
            True targets of shape `(batch_size, n_out)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the reconstruction loss. It is simply a
            weighted sum of the per-step losses.
        """
        max_steps, _ = p.shape
        total_loss = p.new_tensor(0.0)

        for n in range(max_steps):
            # Weight the per-sample loss at step n by the probability of
            # halting at that step.
            loss_per_sample = p[n] * self.loss_func(
                y_pred[n], y_true
            )  # (batch_size,)
            total_loss = total_loss + loss_per_sample.mean()  # scalar

        return total_loss
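

# A minimal sketch of ReconstructionLoss on random tensors. The wrapped
# loss must return one value per sample, hence the `.mean(-1)` over the
# output dimension; all shapes below are arbitrary.
def _demo_reconstruction_loss():
    max_steps, batch_size, n_out = 3, 4, 1
    p = torch.softmax(torch.randn(max_steps, batch_size), dim=0)
    y_pred = torch.randn(max_steps, batch_size, n_out)
    y_true = torch.randint(0, 2, (batch_size, n_out)).float()

    def per_sample_bce(y_p, y_t):
        # (batch_size, n_out) -> (batch_size,)
        return nn.functional.binary_cross_entropy_with_logits(
            y_p, y_t, reduction="none"
        ).mean(-1)

    loss = ReconstructionLoss(per_sample_bce)(p, y_pred, y_true)
    return loss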


class RegularizationLoss(nn.Module):
    """Enforce the halting distribution to resemble the geometric distribution.

    Parameters
    ----------
    lambda_p : float
        The single parameter uniquely determining the geometric distribution.
        Note that the expected value of this distribution is going to be
        `1 / lambda_p`.

    max_steps : int
        Maximum number of pondering steps.
    """

    def __init__(self, lambda_p, max_steps=20):
        super().__init__()

        # Precompute the (truncated) geometric distribution:
        # p_g[k] = (1 - lambda_p) ** k * lambda_p
        p_g = torch.zeros((max_steps,))
        not_halted = 1.0

        for k in range(max_steps):
            p_g[k] = not_halted * lambda_p
            not_halted = not_halted * (1 - lambda_p)

        self.register_buffer("p_g", p_g)
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(self, p):
        """Compute loss.

        Parameters
        ----------
        p : torch.Tensor
            Probability of halting of shape `(steps, batch_size)`.

        Returns
        -------
        loss : torch.Tensor
            Scalar representing the regularization loss.
        """
        steps, batch_size = p.shape

        p = p.transpose(0, 1)  # (batch_size, steps)

        p_g_batch = self.p_g[None, :steps].expand_as(
            p
        )  # (batch_size, steps)

        return self.kl_div(p.log(), p_g_batch)
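

# A minimal sketch combining both losses into a PonderNet-style objective:
# reconstruction loss plus a weighted KL regularizer on the halting
# distribution. The network sizes and the weight `beta` are arbitrary
# illustrative choices.
def _demo_total_loss():
    net = PetriDishNet(n_in=8, n_out=1, n_neurons=64, max_steps=4)
    dataset = ParityDataset(n_samples=2, n_elems=8)
    pairs = [dataset[i] for i in range(2)]
    x = torch.stack([pair[0] for pair in pairs])
    y_true = torch.stack([pair[1] for pair in pairs]).float()

    y_pred, p, halting_step = net(x)

    def per_sample_bce(y_p, y_t):
        # (batch_size, n_out) -> (batch_size,)
        return nn.functional.binary_cross_entropy_with_logits(
            y_p, y_t, reduction="none"
        ).mean(-1)

    rec_loss = ReconstructionLoss(per_sample_bce)(p, y_pred, y_true)
    reg_loss = RegularizationLoss(lambda_p=0.2, max_steps=4)(p)

    beta = 0.01
    return rec_loss + beta * reg_loss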