diff --git a/priorConditionedAnnealing/noise.py b/priorConditionedAnnealing/noise.py
new file mode 100644
index 0000000..1e0c22e
--- /dev/null
+++ b/priorConditionedAnnealing/noise.py
@@ -0,0 +1,86 @@
+import numpy as np
+import torch as th
+import colorednoise as cn
+
+
+class Colored_Noise():
+    def __init__(self, known_shape=None, beta=1, num_samples=2**16, random_state=None):
+        assert known_shape, 'known_shape needs to be defined for Colored_Noise'
+        self.shape = known_shape
+        self.beta = beta
+        self.num_samples = num_samples
+        self.index = 0
+        self.reset(random_state=random_state)
+
+    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
+        # latent is unused; it is only part of the signature for
+        # interface compatibility with SDE_Noise.
+        assert shape == self.shape
+        sample = self.samples[..., self.index]
+        self.index = (self.index + 1) % self.num_samples
+        return th.Tensor(sample)
+
+    def reset(self, random_state=None):
+        # Pre-generate a buffer of colored noise; the last axis is time.
+        self.samples = cn.powerlaw_psd_gaussian(
+            self.beta, self.shape + (self.num_samples,), random_state=random_state)
+
+
+class White_Noise():
+    def __init__(self, known_shape=None):
+        self.known_shape = known_shape
+
+    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
+        return th.Tensor(np.random.normal(0, 1, shape))
+
+
+def get_colored_noise(beta, known_shape=None):
+    if beta == 0:
+        # beta = 0 is white noise; no sample buffer is needed.
+        return White_Noise(known_shape)
+    else:
+        return Colored_Noise(known_shape, beta=beta)
+
+
+class SDE_Noise():
+    def __init__(self, shape, latent_sde_dim=64, Base_Noise=White_Noise):
+        self.shape = shape
+        self.latent_sde_dim = latent_sde_dim
+
+        batch_size = self.shape[0]
+        self.weights_shape = (self.latent_sde_dim,) + self.shape
+        self.weights_shape_batch = (batch_size,) + self.weights_shape
+        self.weights_dist = Base_Noise(self.weights_shape)
+        self.weights_dist_batch = Base_Noise(self.weights_shape_batch)
+
+    def sample_weights(self):
+        # Reparametrization trick to pass gradients
+        self.exploration_mat = self.weights_dist(self.weights_shape)
+        # Pre-compute matrices in case of parallel exploration
+        self.exploration_matrices = self.weights_dist_batch(self.weights_shape_batch)
+
+    def __call__(self, latent: th.Tensor) -> th.Tensor:
+        # self.distribution has to be attached by the owning policy
+        # before the noise is queried.
+        latent_sde = latent.detach()
+        latent_sde = latent_sde[..., -self.latent_sde_dim:]
+        latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
+
+        p = self.distribution
+        if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
+            chol = th.diag_embed(p.stddev)
+        elif isinstance(p, th.distributions.MultivariateNormal):
+            chol = p.scale_tril
+        else:
+            raise ValueError('Unsupported distribution type: ' + str(type(p)))
+
+        # Default case: only one exploration matrix
+        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
+            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]
+
+        # Use batch matrix multiplication for efficient computation
+        # (batch_size, n_features) -> (batch_size, 1, n_features)
+        latent_sde = latent_sde.unsqueeze(dim=1)
+        # (batch_size, 1, n_actions)
+        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
+        return noise.squeeze(dim=1)
diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py
index 8a8234d..dd0a775 100644
--- a/priorConditionedAnnealing/pca.py
+++ b/priorConditionedAnnealing/pca.py
@@ -8,6 +8,8 @@ from stable_baselines3.common.distributions import sum_independent_dims
 from torch.distributions import Normal
 import torch.nn.functional as F
 
+from priorConditionedAnnealing import noise
+
 
 class Par_Strength(Enum):
     SCALAR = 'SCALAR'
@@ -77,12 +79,10 @@ class PCA_Distribution(SB3_Distribution):
         init_std: int = 1,
         window: int = 64,
         epsilon: float = 1e-6,
-        use_sde: bool = False,
+        Base_Noise=noise.White_Noise,
     ):
         super().__init__()
 
-        assert use_sde == False, 'PCA with SDE is not implemented'
-
         self.action_dim = action_dim
         self.kernel_func = cast_to_kernel(kernel_func)
         self.par_strength = cast_to_enum(par_strength, Par_Strength)
@@ -90,6 +90,11 @@ class PCA_Distribution(SB3_Distribution):
         self.window = window
         self.epsilon = epsilon
 
+        if Base_Noise is not noise.White_Noise:
+            print('[!] Non-white noise has not been tested yet!')
+
+        self.base_noise = Base_Noise((1, action_dim))
+
         # Premature optimization is the root of all evil
         self._build_conditioner()
         # *Optimizes it anyways*
@@ -113,11 +118,12 @@ class PCA_Distribution(SB3_Distribution):
     def entropy(self) -> th.Tensor:
         return sum_independent_dims(self.distribution.entropy())
 
-    def sample(self, traj: th.Tensor) -> th.Tensor:
+    def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor:
         pi_mean, pi_std = self.distribution.mean, self.distribution.scale
         rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
         eta = self._get_rigged(pi_mean, pi_std,
-                               rho_mean, rho_std)
+                               rho_mean, rho_std,
+                               epsilon)
         # reparameterization with rigged samples
         actions = pi_mean + pi_std * eta
         self.gaussian_actions = actions
@@ -126,9 +132,10 @@ class PCA_Distribution(SB3_Distribution):
     def is_contextual(self):
         return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
 
-    def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std):
+    def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         with th.no_grad():
-            epsilon = th.Tensor(np.random.normal(0, 1, pi_mean.shape))
+            if epsilon is None:
+                epsilon = self.base_noise(pi_mean.shape)
 
             Delta = rho_mean - pi_mean
             Pi_mu = 1 / pi_std
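For reviewers, a minimal usage sketch (not part of the diff). It assumes PCA_Distribution takes action_dim as its first positional argument; action_dim=2 is an arbitrary illustrative value, and functools.partial is used to fix beta because PCA_Distribution instantiates the noise as Base_Noise((1, action_dim)):

    from functools import partial
    from priorConditionedAnnealing import noise
    from priorConditionedAnnealing.pca import PCA_Distribution

    # Standalone colored-noise generator: one temporally correlated
    # sample per call, drawn from a pre-generated buffer.
    pink = noise.get_colored_noise(beta=1, known_shape=(1, 2))
    eps = pink((1, 2))

    # Plugging a non-default base noise into the distribution;
    # beta=1 (pink noise) is bound via partial since only the
    # shape is passed at construction time.
    dist = PCA_Distribution(
        action_dim=2,
        Base_Noise=partial(noise.Colored_Noise, beta=1),
    )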