132 lines
4.8 KiB
Python
132 lines
4.8 KiB
Python
import numpy as np
|
|
import torch as th
|
|
import colorednoise as cn
|
|
from perlin_noise import PerlinNoise
|
|
from torch.distributions import Normal
|
|
|
|
|
|
class Colored_Noise:
    """Gaussian noise with a power-law spectrum, replayed one time step per call.

    On construction (and on every ``reset``) a bank of ``num_samples`` time
    steps is generated with ``colorednoise.powerlaw_psd_gaussian`` for each of
    the ``prod(known_shape)`` elements. Each call returns the next column of
    that bank reshaped to ``known_shape``; the read cursor wraps back to 0
    after ``num_samples`` calls.
    """

    def __init__(self, known_shape=None, beta=1, num_samples=2**14, random_state=None):
        assert known_shape, 'known_shape need to be defined for Colored Noise'
        self.known_shape = known_shape
        # Number of independent noise channels: known_shape flattened.
        self.compact_shape = np.prod([*known_shape])
        self.beta = beta  # spectral exponent handed to powerlaw_psd_gaussian
        self.num_samples = num_samples  # length of the pre-generated bank
        self.index = 0  # read cursor into the bank
        self.reset(random_state=random_state)

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        """Return the next time step of the bank as a tensor of ``known_shape``.

        ``shape`` must equal ``known_shape``; ``latent`` is accepted only for
        interface compatibility and is ignored.
        """
        assert shape == self.known_shape
        step = self.samples[:, self.index]
        # Advance the cursor, wrapping to the start of the bank.
        self.index = (self.index + 1) % self.num_samples
        return th.Tensor(step).view(self.known_shape)

    def reset(self, random_state=None):
        """Regenerate the noise bank; the read cursor is left unchanged."""
        bank_shape = (self.compact_shape, self.num_samples)
        self.samples = cn.powerlaw_psd_gaussian(self.beta, bank_shape, random_state=random_state)
|
|
|
|
|
|
class Pink_Noise(Colored_Noise):
    """Colored_Noise with the spectral exponent pinned to beta = 1."""

    def __init__(self, known_shape=None, num_samples=2**14, random_state=None):
        # Everything except beta is forwarded unchanged to Colored_Noise.
        super().__init__(
            beta=1,
            known_shape=known_shape,
            random_state=random_state,
            num_samples=num_samples,
        )
|
|
|
|
|
|
class White_Noise:
    """Uncorrelated standard-normal noise; every call draws fresh samples."""

    def __init__(self, known_shape=None):
        # Kept for interface parity with the other noise classes; __call__
        # takes the shape from its own argument instead.
        self.known_shape = known_shape

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        """Draw a ``shape``-sized tensor from N(0, 1) using NumPy's global RNG.

        ``latent`` is accepted for interface compatibility and ignored.
        """
        draws = np.random.normal(loc=0, scale=1, size=shape)
        return th.Tensor(draws)
|
|
|
|
|
|
def get_colored_noise(beta, known_shape=None):
    """Return a noise generator for spectral exponent ``beta``.

    beta == 0 yields White_Noise, beta == 1 yields Pink_Noise, and any other
    value yields a Colored_Noise with that exponent.
    """
    if beta == 0:
        noise = White_Noise(known_shape)
    elif beta == 1:
        noise = Pink_Noise(known_shape)
    else:
        noise = Colored_Noise(known_shape, beta=beta)
    return noise
|
|
|
|
|
|
class SDE_Noise:
    """State-dependent exploration noise built from sampled exploration matrices.

    For a latent ``z`` the noise is ``normalize(z[..., -latent_sde_dim:]) @ W @ L``.
    ``W`` is an exploration matrix drawn by ``sample_weights()``. ``L`` is a
    scale factor derived from ``self.distribution``.

    Expected usage, as implied by the code: assign ``self.distribution`` (a
    torch ``Normal``, ``Independent`` or ``MultivariateNormal``), call
    ``sample_weights()``, then call the instance with a latent tensor.
    """

    def __init__(self, shape, latent_sde_dim=64, Base_Noise=White_Noise):
        """
        Args:
            shape: trailing dimensions of the sampled weights.
                NOTE(review): ``shape[0]`` is also used as the batch size of
                the per-sample exploration matrices. Confirm that is intended.
            latent_sde_dim: number of trailing latent features fed into the noise.
            Base_Noise: noise class instantiated with the full weight shape.
                Its instances are sampled via ``sample()`` when they have one.
                Otherwise they are called with that shape, as the noise
                classes in this module expect.
        """
        self.shape = shape
        self.latent_sde_dim = latent_sde_dim
        self.Base_Noise = Base_Noise
        # Must be assigned by the owner before __call__; None means "not set".
        self.distribution = None

        batch_size = self.shape[0]
        # (latent_sde_dim, *shape) for the shared matrix and
        # (batch_size, latent_sde_dim, *shape) for the per-sample matrices.
        self._weights_shape = (self.latent_sde_dim,) + tuple(self.shape)
        self._batch_weights_shape = (batch_size,) + self._weights_shape
        self.weights_dist = self.Base_Noise(self._weights_shape)
        self.weights_dist_batch = self.Base_Noise(self._batch_weights_shape)

    @staticmethod
    def _draw(noise, shape):
        """Draw one tensor from ``noise`` via ``sample()`` if available, else ``noise(shape)``."""
        sampler = getattr(noise, 'sample', None)
        if callable(sampler):
            return sampler()
        # The noise classes in this module expose __call__(shape), not sample().
        return noise(shape)

    def sample_weights(self):
        """Resample the shared exploration matrix and the per-sample batch of matrices."""
        self.exploration_mat = self._draw(self.weights_dist, self._weights_shape)
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self._draw(self.weights_dist_batch, self._batch_weights_shape)

    def _scale_factor(self):
        """Return the Cholesky-like factor of ``self.distribution``'s covariance.

        Raises:
            RuntimeError: if ``self.distribution`` has not been assigned.
            TypeError: if the distribution type is not supported.
        """
        p = self.distribution
        if p is None:
            raise RuntimeError('SDE_Noise.distribution must be set before sampling noise')
        if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
            # Diagonal covariance: the factor is diag(stddev).
            return th.diag_embed(p.stddev)
        if isinstance(p, th.distributions.MultivariateNormal):
            return p.scale_tril
        raise TypeError(f'SDE_Noise does not support distribution type {type(p).__name__}')

    def __call__(self, latent: th.Tensor) -> th.Tensor:
        """Map ``latent`` to exploration noise.

        ``latent`` is detached, so no gradient flows back through it. Only its
        last ``latent_sde_dim`` features are used, L2-normalized along the last
        axis. ``th.mm`` / ``th.bmm`` below require the sliced latent to be 2-D,
        presumably (batch, features). TODO confirm with callers.

        Raises:
            RuntimeError: if ``sample_weights()`` was never called or
                ``self.distribution`` is unset.
            TypeError: if ``self.distribution`` has an unsupported type.
        """
        if not hasattr(self, 'exploration_matrices'):
            raise RuntimeError('SDE_Noise.sample_weights() must be called before sampling noise')

        latent_sde = latent.detach()
        latent_sde = latent_sde[..., -self.latent_sde_dim:]
        latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)

        chol = self._scale_factor()

        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            # NOTE(review): `[0]` keeps only the first entry along the leading
            # dim of the product. With a 2-D chol that is the noise for the
            # first latent row, so a mismatched batch gets one row back.
            # Confirm this is intended.
            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]

        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
        return noise.squeeze(dim=1)
|
|
|
|
|
|
class Perlin_Noise:
    """Smooth noise sampled from a 2-D Perlin field, one time step per call.

    Each call increments ``index`` and returns ``shape[-1]`` values. Value
    ``a`` is the field at ``(index * scale, magic * a)``, divided by
    ``normal_factor``.
    """

    def __init__(self, known_shape=None, scale=0.1, octave=1):
        self.known_shape = known_shape  # stored for interface parity; not used when sampling
        self.scale = scale  # step along the first (time) coordinate per call
        self.octave = octave  # forwarded as PerlinNoise(octaves=...)
        self.magic = 3.141592653589  # Axis offset, should be (kinda) irrational
        # Divisor meant to rescale the raw Perlin output so samples are roughly ~N(0, 1).
        self.normal_factor = 14/99
        self.reset()

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        """Advance one time step and return a 1-D tensor of ``shape[-1]`` values.

        Args:
            shape: only its last entry is used.
            latent: accepted for call compatibility with the other noise
                classes in this module, which take ``(shape, latent=None)``.
                It is ignored.
        """
        self.index += 1
        # The time coordinate is shared by every output element of this step.
        t = self.index * self.scale
        values = [self.noise([t, self.magic * a]) / self.normal_factor
                  for a in range(shape[-1])]
        return th.Tensor(values)

    def reset(self):
        """Restart the time index and build a new PerlinNoise field (no seed is passed)."""
        self.index = 0
        self.noise = PerlinNoise(octaves=self.octave)
|
|
|
|
|
|
class Harmonic_Perlin_Noise:
    """Weighted sum of Perlin_Noise layers; layer k (1-based) uses ``octave=k``.

    ``octaves`` is either the number of layers, giving weights 1, 1/2, 1/3, ...,
    or an explicit sequence of per-layer weights. Either way the weights are
    scaled to unit L2 norm.
    """

    def __init__(self, known_shape=None, scale=0.1, octaves=8):
        self.known_shape = known_shape
        self.scale = scale
        # An integer count expands to harmonic weights 1/1, 1/2, ..., 1/n.
        # NumPy integers are accepted as well. A `type(...) == int` check
        # would reject them and then fail in reset() on a 0-d array.
        if isinstance(octaves, (int, np.integer)):
            octaves = [1/(i+1) for i in range(octaves)]
        octaves = np.array(octaves)
        # Normalize so the layer weights have unit L2 norm.
        self.octaves = octaves / np.linalg.norm(octaves)
        self.reset()

    def __call__(self, shape, latent: th.Tensor = None):
        """Return the weighted sum of one step from every layer.

        ``shape`` is forwarded to each layer. ``latent`` is accepted for call
        compatibility with the other noise classes in this module and ignored.
        """
        return sum(noise(shape) * weight for noise, weight in zip(self.noises, self.octaves))

    def reset(self):
        """Rebuild one Perlin_Noise layer per weight; layer k (1-based) uses ``octave=k``."""
        self.index = 0  # not read by this class; kept so the attribute still exists
        self.noises = [
            Perlin_Noise(known_shape=self.known_shape, scale=self.scale, octave=k)
            for k in range(1, len(self.octaves) + 1)
        ]
|