Implemented support for different base distributions

parent 44b34fe12b
commit b7ab7d0664

priorConditionedAnnealing/noise.py (new file, 80 additions)

@@ -0,0 +1,80 @@
import numpy as np
import torch as th
import colorednoise as cn
from torch.distributions import Normal


class Colored_Noise():
    def __init__(self, known_shape=None, beta=1, num_samples=2**16, random_state=None):
        assert known_shape, 'known_shape needs to be defined for Colored_Noise'
        self.shape = known_shape
        self.beta = beta  # power-law exponent: 0 = white, 1 = pink, 2 = red/Brownian
        self.num_samples = num_samples  # length of the pre-sampled noise buffer
        self.index = 0
        self.reset(random_state=random_state)

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        assert shape == self.shape
        # Consume the buffer one time step per call, wrapping around at the end.
        sample = self.samples[..., self.index]
        self.index = (self.index + 1) % self.num_samples
        return th.Tensor(sample)

    def reset(self, random_state=None):
        # Pre-sample the whole buffer at once; powerlaw_psd_gaussian treats the
        # last axis as the time axis, so correlations run across calls.
        self.samples = cn.powerlaw_psd_gaussian(
            self.beta, self.shape + (self.num_samples,), random_state=random_state)


class White_Noise():
    def __init__(self, known_shape=None):
        self.known_shape = known_shape

    def __call__(self, shape, latent: th.Tensor = None) -> th.Tensor:
        return th.Tensor(np.random.normal(0, 1, shape))


def get_colored_noise(beta, known_shape=None):
    if beta == 0:
        return White_Noise(known_shape)
    else:
        return Colored_Noise(known_shape, beta=beta)


class SDE_Noise():
    def __init__(self, shape, latent_sde_dim=64, Base_Noise=White_Noise):
        self.shape = shape
        self.latent_sde_dim = latent_sde_dim
        self.Base_Noise = Base_Noise
        self.batch_size = self.shape[0]

        self.weights_dist = self.Base_Noise(
            (self.latent_sde_dim,) + self.shape)
        self.weights_dist_batch = self.Base_Noise(
            (self.batch_size, self.latent_sde_dim,) + self.shape)

    def sample_weights(self):
        # Reparametrization trick to pass gradients
        self.exploration_mat = self.weights_dist(
            (self.latent_sde_dim,) + self.shape)
        # Pre-compute matrices in case of parallel exploration
        self.exploration_matrices = self.weights_dist_batch(
            (self.batch_size, self.latent_sde_dim,) + self.shape)

    def __call__(self, latent: th.Tensor) -> th.Tensor:
        # self.distribution is expected to be attached by the caller before use.
        latent_sde = latent.detach()
        latent_sde = latent_sde[..., -self.latent_sde_dim:]
        latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)

        p = self.distribution
        if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
            chol = th.diag_embed(p.stddev)
        elif isinstance(p, th.distributions.MultivariateNormal):
            chol = p.scale_tril
        else:
            raise ValueError('Unsupported distribution type for SDE_Noise')

        # Default case: only one exploration matrix
        if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
            return (th.mm(latent_sde, self.exploration_mat) @ chol)[0]

        # Use batch matrix multiplication for efficient computation
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = th.bmm(th.bmm(latent_sde, self.exploration_matrices), chol)
        return noise.squeeze(dim=1)
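
For context, a minimal usage sketch of the module above (illustrative only, not part of the commit): get_colored_noise dispatches on beta, and Colored_Noise serves one slice of its pre-sampled buffer per call, so consecutive samples are temporally correlated.

# Illustrative usage sketch (hypothetical shapes; not part of the commit).
from priorConditionedAnnealing.noise import get_colored_noise

shape = (1, 3)                                        # one env, three action dims
pink = get_colored_noise(beta=1, known_shape=shape)   # -> Colored_Noise
white = get_colored_noise(beta=0, known_shape=shape)  # -> White_Noise

eps_t0 = pink(shape)   # buffer slice at index 0, shape (1, 3)
eps_t1 = pink(shape)   # next slice, correlated with eps_t0 over time
eps_w = white(shape)   # fresh i.i.d. Gaussian draw, shape (1, 3)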
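
The batched branch of SDE_Noise.__call__ closely follows the exploration-matrix logic of Stable-Baselines3's gSDE: each batch entry multiplies its normalized latent with its own exploration matrix, and the Cholesky factor of the policy covariance then scales the result. A standalone shape check with hypothetical sizes:

# Standalone shape check for the batched path (hypothetical sizes).
import torch as th

B, D, A = 8, 64, 3                          # batch, latent dims, action dims
latent_sde = th.nn.functional.normalize(th.randn(B, D), dim=-1)
exploration_matrices = th.randn(B, D, A)    # one weight matrix per batch entry
chol = th.diag_embed(th.ones(B, A))         # stddev = 1 -> identity scaling

noise = th.bmm(th.bmm(latent_sde.unsqueeze(1), exploration_matrices), chol)
assert noise.squeeze(1).shape == (B, A)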

@@ -8,6 +8,8 @@ from stable_baselines3.common.distributions import sum_independent_dims
 from torch.distributions import Normal
 import torch.nn.functional as F
 
+from priorConditionedAnnealing import noise
+
 
 class Par_Strength(Enum):
     SCALAR = 'SCALAR'
@@ -77,12 +79,10 @@ class PCA_Distribution(SB3_Distribution):
         init_std: int = 1,
         window: int = 64,
         epsilon: float = 1e-6,
-        use_sde: bool = False,
+        Base_Noise=noise.White_Noise
     ):
         super().__init__()
 
-        assert use_sde == False, 'PCA with SDE is not implemented'
-
         self.action_dim = action_dim
         self.kernel_func = cast_to_kernel(kernel_func)
         self.par_strength = cast_to_enum(par_strength, Par_Strength)
@@ -90,6 +90,11 @@ class PCA_Distribution(SB3_Distribution):
         self.window = window
         self.epsilon = epsilon
 
+        if Base_Noise != noise.White_Noise:
+            print('[!] Non-White Noise was not yet tested!')
+
+        self.base_noise = Base_Noise((1,) + action_dim)
+
         # Premature optimization is the root of all evil
         self._build_conditioner()
         # *Optimizes it anyways*
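
Because the constructor instantiates Base_Noise itself, a Colored_Noise with a non-default beta has to be bound before it is passed in, e.g. via functools.partial. A hypothetical wiring (the action_dim value and any omitted constructor arguments are assumptions):

# Hypothetical wiring; argument names are taken from the hunks above.
from functools import partial
from priorConditionedAnnealing import noise

dist = PCA_Distribution(
    action_dim=(3,),
    Base_Noise=partial(noise.Colored_Noise, beta=1),  # pink base noise
)

Passing anything other than noise.White_Noise trips the '[!] Non-White Noise was not yet tested!' warning above, which is expected here.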
@@ -113,11 +118,12 @@ class PCA_Distribution(SB3_Distribution):
     def entropy(self) -> th.Tensor:
         return sum_independent_dims(self.distribution.entropy())
 
-    def sample(self, traj: th.Tensor) -> th.Tensor:
+    def sample(self, traj: th.Tensor, epsilon=None) -> th.Tensor:
         pi_mean, pi_std = self.distribution.mean, self.distribution.scale
         rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
         eta = self._get_rigged(pi_mean, pi_std,
-                               rho_mean, rho_std)
+                               rho_mean, rho_std,
+                               epsilon)
         # reparameterization with rigged samples
         actions = pi_mean + pi_std * eta
         self.gaussian_actions = actions
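
The epsilon pass-through added here makes sampling externally steerable: a caller can inject its own noise instead of letting _get_rigged draw from base_noise (see the next hunk). A hedged sketch, assuming dist and traj are set up as elsewhere in the class:

# Hypothetical: inject zero noise to obtain the deterministic rigged action.
eps = th.zeros(1, 3)                     # must match pi_mean.shape
action = dist.sample(traj, epsilon=eps)  # eta is computed from this epsilon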
@@ -126,9 +132,10 @@ class PCA_Distribution(SB3_Distribution):
     def is_contextual(self):
         return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
 
-    def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std):
+    def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std, epsilon=None):
         with th.no_grad():
-            epsilon = th.Tensor(np.random.normal(0, 1, pi_mean.shape))
+            if epsilon is None:
+                epsilon = self.base_noise(pi_mean.shape)
 
             Delta = rho_mean - pi_mean
             Pi_mu = 1 / pi_std