PriorConditionedAnnealing/priorConditionedAnnealing/noise.py

81 lines
3.0 KiB
Python
Raw Normal View History

import numpy as np
import torch as th
import colorednoise as cn
from torch.distributions import Normal
class Colored_Noise():
    """Pre-generated colored (power-law) Gaussian noise, served one step at a time.

    A buffer of ``num_samples`` consecutive noise values is generated up front
    with ``colorednoise.powerlaw_psd_gaussian``. Every call returns the next
    time step from that buffer and wraps around once the end is reached.
    """

    def __init__(self, known_shape=None, beta=1, num_samples=2**16, random_state=None):
        """Create the noise source and fill its sample buffer.

        Args:
            known_shape: Shape of a single noise sample. Required, because the
                whole buffer is generated before the first call.
            beta: Exponent passed to ``powerlaw_psd_gaussian``.
            num_samples: Number of time steps held in the buffer.
            random_state: Forwarded to ``powerlaw_psd_gaussian``.
        """
        assert known_shape, 'known_shape need to be defined for Colored Noise'
        # Normalised to a tuple so it can be concatenated with the time axis
        # and compared against the shape requested in __call__.
        self.known_shape = tuple(known_shape)
        self.beta = beta
        self.num_samples = num_samples
        self.index = 0
        self.reset(random_state=random_state)

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        """Return the next noise sample of shape ``known_shape``.

        Args:
            shape: Requested shape; must equal ``known_shape``.
            latent: Unused by this method; accepted so the call signature
                matches the other noise sources in this file.

        Returns:
            A tensor copy of the current time step of the buffer.
        """
        assert tuple(shape) == self.known_shape
        # The buffer is laid out as known_shape + (num_samples,), i.e. time is
        # the LAST axis, so index with an ellipsis rather than `[:, i]`.
        sample = self.samples[..., self.index]
        self.index = (self.index + 1) % self.num_samples
        # th.tensor copies, so callers cannot mutate the shared buffer.
        return th.tensor(sample, dtype=th.get_default_dtype())

    def reset(self, random_state=None):
        """Regenerate the whole sample buffer (the read index is left as is)."""
        self.samples = cn.powerlaw_psd_gaussian(
            self.beta, self.known_shape + (self.num_samples,), random_state=random_state)
class White_Noise():
    """Uncorrelated standard-normal noise source."""

    def __init__(self, known_shape=None):
        """Remember ``known_shape``; it is not consulted when sampling."""
        self.known_shape = known_shape

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        """Draw a fresh N(0, 1) sample of the requested ``shape``.

        Args:
            shape: Shape of the sample to draw.
            latent: Unused by this method.

        Returns:
            A tensor built from a NumPy draw (uses NumPy's global RNG).
        """
        draw = np.random.normal(loc=0, scale=1, size=shape)
        return th.Tensor(draw)
def get_colored_noise(beta, known_shape=None):
    """Build the noise source matching the exponent ``beta``.

    Args:
        beta: Noise exponent; ``0`` selects white noise.
        known_shape: Sample shape forwarded to the noise source.

    Returns:
        A ``White_Noise`` instance when ``beta == 0``, otherwise a
        ``Colored_Noise`` instance configured with ``beta``.
    """
    return (
        White_Noise(known_shape)
        if beta == 0
        else Colored_Noise(known_shape, beta=beta)
    )
class SDE_Noise():
    """State-dependent exploration noise built on a pluggable base noise source.

    Exploration matrices are drawn from ``Base_Noise`` and multiplied with the
    (detached, normalised) latent features and a Cholesky-like factor taken
    from ``self.distribution``.

    NOTE(review): ``self.distribution`` is not assigned anywhere in this
    class; it presumably has to be set by the owner before calling the
    instance - confirm against the caller.
    """

    def __init__(self, shape, latent_sde_dim=64, Base_Noise=White_Noise):
        """Set up the two base noise sources.

        Args:
            shape: Shape tuple of the noise; ``shape[0]`` is used as the
                batch size for the batched exploration matrices.
            latent_sde_dim: Number of trailing latent features that drive
                the noise.
            Base_Noise: Noise class, instantiated with the weight shape.
        """
        self.shape = shape
        self.latent_sde_dim = latent_sde_dim
        self.Base_Noise = Base_Noise

        batch_size = self.shape[0]
        # Kept so sample_weights can request these shapes from the sources.
        self._weights_shape = (self.latent_sde_dim,) + self.shape
        self._weights_batch_shape = (batch_size,) + self._weights_shape
        self.weights_dist = self.Base_Noise(self._weights_shape)
        self.weights_dist_batch = self.Base_Noise(self._weights_batch_shape)

    @staticmethod
    def _draw(source, shape):
        """Draw one sample from ``source``, whichever protocol it offers."""
        sampler = getattr(source, 'sample', None)
        if callable(sampler):
            # Distribution-style object exposing .sample().
            return sampler()
        # The noise classes in this file are callables taking the shape.
        return source(shape)

    def sample_weights(self):
        """Draw new exploration matrices from the base noise sources."""
        self.exploration_mat = self._draw(
            self.weights_dist, self._weights_shape)
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self._draw(
            self.weights_dist_batch, self._weights_batch_shape)

    def __call__(self, latent: th.Tensor) -> th.Tensor:
        """Compute exploration noise for ``latent``.

        Args:
            latent: Latent features; only the last ``latent_sde_dim`` entries
                along the final axis are used, detached from the graph.

        Returns:
            The noise tensor for the given latent features.

        Raises:
            TypeError: If ``self.distribution`` is not a ``Normal``,
                ``Independent`` or ``MultivariateNormal`` distribution.
        """
        latent_sde = latent.detach()
        latent_sde = latent_sde[..., -self.latent_sde_dim:]
        latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)

        p = self.distribution
        if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
            chol = th.diag_embed(p.stddev)
        elif isinstance(p, th.distributions.MultivariateNormal):
            chol = p.scale_tril
        else:
            # Previously fell through and failed later with an unbound `chol`.
            raise TypeError(
                'Unsupported distribution type for SDE_Noise: '
                + type(p).__name__)

        # Draw the exploration matrices on first use if nobody has yet.
        if not hasattr(self, 'exploration_mat'):
            self.sample_weights()

        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]

        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
        return noise.squeeze(dim=1)