Initial commit
This commit is contained in:
		
						commit
						1d2defcb27
					
				
							
								
								
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					secret.json
 | 
				
			||||||
 | 
					secret.json*
 | 
				
			||||||
 | 
					.venv
 | 
				
			||||||
 | 
					__pycache__
 | 
				
			||||||
 | 
					*/__pycache__
 | 
				
			||||||
 | 
					*/*/__pycache__
 | 
				
			||||||
 | 
					*.pyc
 | 
				
			||||||
							
								
								
									
										3
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					# Prior Conditioned Annealing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Because Movement Primitives are ugly and regular Step-Based is incompetent at exploration.
 | 
				
			||||||
							
								
								
									
										284
									
								
								pca.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										284
									
								
								pca.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,284 @@
 | 
				
			|||||||
 | 
					from enum import Enum
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import torch as th
 | 
				
			||||||
 | 
					import scipy.spatial
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					from stable_baselines3.common.distributions import Distribution as SB3_Distribution
 | 
				
			||||||
 | 
					from stable_baselines3.common.distributions import sum_independent_dims
 | 
				
			||||||
 | 
					from torch.distributions import Normal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Par_Strength(Enum):
 | 
				
			||||||
 | 
					    SCALAR = 'SCALAR'
 | 
				
			||||||
 | 
					    DIAG = 'DIAG'
 | 
				
			||||||
 | 
					    CONT_SCALAR = 'CONT_SCALAR'
 | 
				
			||||||
 | 
					    CONT_DIAG = 'CONT_DIAG'
 | 
				
			||||||
 | 
					    CONT_HYBRID = 'CONT_HYBRID'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EnforcePositiveType(Enum):
 | 
				
			||||||
 | 
					    # This need to be implemented in this ugly fashion,
 | 
				
			||||||
 | 
					    # because cloudpickle does not like more complex enums
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    NONE = 0
 | 
				
			||||||
 | 
					    SOFTPLUS = 1
 | 
				
			||||||
 | 
					    ABS = 2
 | 
				
			||||||
 | 
					    RELU = 3
 | 
				
			||||||
 | 
					    LOG = 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def apply(self, x):
 | 
				
			||||||
 | 
					        # aaaaaa
 | 
				
			||||||
 | 
					        return [nn.Identity(), nn.Softplus(beta=1, threshold=20), th.abs, nn.ReLU(inplace=False), th.log][self.value](x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rbf(l=1.0):
 | 
				
			||||||
 | 
					    return se(sig=1.0, l=l)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def se(sig=1.0, l=1.0):
 | 
				
			||||||
 | 
					    def func(xa, xb):
 | 
				
			||||||
 | 
					        sq = scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / -2*l
 | 
				
			||||||
 | 
					        return (sig**2)*np.exp(sq)
 | 
				
			||||||
 | 
					    return func
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Avaible_Kernel_Funcs(Enum):
 | 
				
			||||||
 | 
					    RBF = 0
 | 
				
			||||||
 | 
					    SE = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_func(self):
 | 
				
			||||||
 | 
					        # stil aaaaaaaa
 | 
				
			||||||
 | 
					        return [rbf, se][self.value]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cast_to_enum(inp, Class):
 | 
				
			||||||
 | 
					    if isinstance(inp, Enum):
 | 
				
			||||||
 | 
					        return inp
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return Class[inp]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cast_to_kernel(inp):
 | 
				
			||||||
 | 
					    if isinstance(inp, function):
 | 
				
			||||||
 | 
					        return inp
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        func, *pars = inp.split('_')
 | 
				
			||||||
 | 
					        return Avaible_Kernel_Funcs[func](*pars)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PCA_Distribution(SB3_Distribution):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        action_dim: int,
 | 
				
			||||||
 | 
					        par_strength: Par_Strength = Par_Strength.CONT_DIAG,
 | 
				
			||||||
 | 
					        kernel_func=rbf(),
 | 
				
			||||||
 | 
					        init_std: int = 1,
 | 
				
			||||||
 | 
					        window: int = 64,
 | 
				
			||||||
 | 
					        epsilon: float = 1e-6,
 | 
				
			||||||
 | 
					        use_sde: bool = False,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert use_sde == False, 'PCA with SDE is not implemented'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.action_dim = action_dim
 | 
				
			||||||
 | 
					        self.kernel_func = cast_to_kernel(kernel_func)
 | 
				
			||||||
 | 
					        self.par_strength = cast_to_enum(Par_Strength, par_strength)
 | 
				
			||||||
 | 
					        self.init_std = init_std
 | 
				
			||||||
 | 
					        self.window = window
 | 
				
			||||||
 | 
					        self.epsilon = epsilon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Premature optimization is the root of all evil
 | 
				
			||||||
 | 
					        self._build_conditioner()
 | 
				
			||||||
 | 
					        # *Optimizes it anyways*
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def proba_distribution_net(self, latent_dim: int):
 | 
				
			||||||
 | 
					        mu_net = nn.Linear(latent_dim, self.action_dim)
 | 
				
			||||||
 | 
					        std_net = StdNet(latent_dim, self.action_dim,
 | 
				
			||||||
 | 
					                         self.init_std, self.par_strength)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return mu_net, std_net
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def proba_distribution(
 | 
				
			||||||
 | 
					            self, mean_actions: th.Tensor, std_actions: th.Tensor) -> SB3_Distribution:
 | 
				
			||||||
 | 
					        self.distribution = Normal(
 | 
				
			||||||
 | 
					            mean_actions, std_actions)
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def log_prob(self, actions: th.Tensor) -> th.Tensor:
 | 
				
			||||||
 | 
					        return sum_independent_dims(self.distribution.log_prob(actions))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def entropy(self) -> th.Tensor:
 | 
				
			||||||
 | 
					        return sum_independent_dims(self.distribution.entropy())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def sample(self, traj: th.Tensor) -> th.Tensor:
 | 
				
			||||||
 | 
					        pi_mean, pi_std = self.distribution.mean, self.distribution.scale,
 | 
				
			||||||
 | 
					        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
 | 
				
			||||||
 | 
					        eta = self._get_rigged(pi_mean, pi_std,
 | 
				
			||||||
 | 
					                               rho_mean, rho_std)
 | 
				
			||||||
 | 
					        # reparameterization with rigged samples
 | 
				
			||||||
 | 
					        actions = pi_mean + pi_std * eta
 | 
				
			||||||
 | 
					        self.gaussian_actions = actions
 | 
				
			||||||
 | 
					        return actions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def is_contextual(self):
 | 
				
			||||||
 | 
					        return self.par_strength not in [Par_Strength.SCALAR, Par_Strength.DIAG]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _get_rigged(self, pi_mean, pi_std, rho_mean, rho_std):
 | 
				
			||||||
 | 
					        with th.no_grad():
 | 
				
			||||||
 | 
					            epsilon = th.Tensor(np.random.normal(0, 1, pi_mean.shape))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            Delta = rho_mean - pi_mean
 | 
				
			||||||
 | 
					            Pi_mu = 1 / pi_std
 | 
				
			||||||
 | 
					            Pi_sigma = rho_std / pi_std
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            eta = Pi_mu * Delta + Pi_sigma * epsilon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return eta.detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _conditioning_engine(self, traj, pi_mean, pi_std):
 | 
				
			||||||
 | 
					        y_np = np.append(np.swapaxes(traj, -1, -2),
 | 
				
			||||||
 | 
					                         np.expand_dims(pi_mean, -1), -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with th.no_grad():
 | 
				
			||||||
 | 
					            conditioners = th.Tensor(self._adapt_conditioner(pi_std))
 | 
				
			||||||
 | 
					            y = th.Tensor(y_np)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            S = th.cholesky_solve(self.Sig12.expand(
 | 
				
			||||||
 | 
					                conditioners.shape[:-1]).unsqueeze(-1), conditioners).squeeze(-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rho_mean = th.einsum('bai,bai->ba', S, y)
 | 
				
			||||||
 | 
					            rho_std = self.Sig22 - (S @ self.Sig12)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return rho_mean, rho_std
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _build_conditioner(self):
 | 
				
			||||||
 | 
					        # Precomputes the Cholesky decomp of the cov matrix to be used as a pseudoinverse.
 | 
				
			||||||
 | 
					        # Also precomputes some auxilary stuff for _adapt_conditioner.
 | 
				
			||||||
 | 
					        w = self.window
 | 
				
			||||||
 | 
					        Z = np.linspace(0, w, w+1).reshape(-1, 1)
 | 
				
			||||||
 | 
					        X = np.array([w]).reshape(-1, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Sig11 = self.kernel_func(Z, Z)
 | 
				
			||||||
 | 
					        self.Sig12 = th.Tensor(self.kernel_func(Z, X)).squeeze(-1)
 | 
				
			||||||
 | 
					        self.Sig22 = th.Tensor(self.kernel_func(
 | 
				
			||||||
 | 
					            X, X)).squeeze(-1).squeeze(-1)
 | 
				
			||||||
 | 
					        self.conditioner = np.linalg.cholesky(Sig11)
 | 
				
			||||||
 | 
					        self.adapt_norm = np.linalg.norm(
 | 
				
			||||||
 | 
					            self.conditioner[-1, :][:-1], axis=-1)**2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _adapt_conditioner(self, pi_std):
 | 
				
			||||||
 | 
					        # We can not actually precompute the cov inverse completely,
 | 
				
			||||||
 | 
					        # since it also depends on the current policies sigma.
 | 
				
			||||||
 | 
					        # But, because of the way the Cholesky Decomp works,
 | 
				
			||||||
 | 
					        # we can use the precomputed L (conditioner)
 | 
				
			||||||
 | 
					        # (which is computed by an efficient LAPACK implementation)
 | 
				
			||||||
 | 
					        # and adapt it for our new k(x_w+1,x_w+1) value (in python)
 | 
				
			||||||
 | 
					        # (Which is dependent on pi)
 | 
				
			||||||
 | 
					        # S_{ij} = \frac{1}{D_j} \left( A_{ij} - \sum_{k=1}^{j-1} S_{ik} S_{jk} D_k \right), \qquad\text{for } i>j
 | 
				
			||||||
 | 
					        # https://martin-thoma.com/images/2012/07/cholesky-zerlegung-numerik.png
 | 
				
			||||||
 | 
					        # This way conditioning of the GP can be done in O(dim(A)) time.
 | 
				
			||||||
 | 
					        if not self.is_contextual():
 | 
				
			||||||
 | 
					            # safe inplace
 | 
				
			||||||
 | 
					            self.conditioner[-1, -
 | 
				
			||||||
 | 
					                             1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)
 | 
				
			||||||
 | 
					            return np.expand_dims(np.expand_dims(self.conditioner, 0), 0)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            conditioner = np.zeros(
 | 
				
			||||||
 | 
					                (pi_std.shape[0], pi_std.shape[1]) + self.conditioner.shape)
 | 
				
			||||||
 | 
					            conditioner[:, :] = self.conditioner
 | 
				
			||||||
 | 
					            conditioner[:, :, -1, -
 | 
				
			||||||
 | 
					                        1] = np.sqrt(pi_std**2 + self.Sig22 - self.adapt_norm)
 | 
				
			||||||
 | 
					            return conditioner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def mode(self) -> th.Tensor:
 | 
				
			||||||
 | 
					        return self.distribution.mean
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def actions_from_params(
 | 
				
			||||||
 | 
					        self, mean: th.Tensor, std: th.Tensor, deterministic: bool = False
 | 
				
			||||||
 | 
					    ) -> th.Tensor:
 | 
				
			||||||
 | 
					        self.proba_distribution(mean, std)
 | 
				
			||||||
 | 
					        return self.get_actions(deterministic=deterministic)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def log_prob_from_params(self, mean: th.Tensor, std: th.Tensor):
 | 
				
			||||||
 | 
					        actions = self.actions_from_params(mean, std)
 | 
				
			||||||
 | 
					        log_prob = self.log_prob(actions)
 | 
				
			||||||
 | 
					        return actions, log_prob
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def print_info(self, traj: th.Tensor):
 | 
				
			||||||
 | 
					        pi_mean, pi_std = self.distribution.mean, self.distribution.scale,
 | 
				
			||||||
 | 
					        rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
 | 
				
			||||||
 | 
					        eta = self._get_rigged(pi_mean, pi_std,
 | 
				
			||||||
 | 
					                               rho_mean, rho_std)
 | 
				
			||||||
 | 
					        print('pi  ~ N('+str(pi_mean)+','+str(pi_std)+')')
 | 
				
			||||||
 | 
					        print('rho ~ N('+str(rho_mean)+','+str(rho_std)+')')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StdNet(nn.Module):
 | 
				
			||||||
 | 
					    def __init__(self, latent_dim: int, action_dim: int, std_init: float, par_strength: bool, epsilon: float):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.action_dim = action_dim
 | 
				
			||||||
 | 
					        self.latent_dim = latent_dim
 | 
				
			||||||
 | 
					        self.std_init = std_init
 | 
				
			||||||
 | 
					        self.par_strength = par_strength
 | 
				
			||||||
 | 
					        self.enforce_positive_type = EnforcePositiveType.SOFTPLUS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.epsilon = epsilon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.par_strength == Par_Strength.SCALAR:
 | 
				
			||||||
 | 
					            self.param = nn.Parameter(
 | 
				
			||||||
 | 
					                th.Tensor([std_init]), requires_grad=True)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.DIAG:
 | 
				
			||||||
 | 
					            self.param = nn.Parameter(
 | 
				
			||||||
 | 
					                th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_SCALAR:
 | 
				
			||||||
 | 
					            self.net = nn.Linear(latent_dim, 1)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_HYBRID:
 | 
				
			||||||
 | 
					            self.net = nn.Linear(latent_dim, 1)
 | 
				
			||||||
 | 
					            self.param = nn.Parameter(
 | 
				
			||||||
 | 
					                th.Tensor(th.ones(action_dim)*std_init), requires_grad=True)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_DIAG:
 | 
				
			||||||
 | 
					            self.net = nn.Linear(latent_dim, self.action_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x: th.Tensor) -> th.Tensor:
 | 
				
			||||||
 | 
					        if self.par_strength == Par_Strength.SCALAR:
 | 
				
			||||||
 | 
					            return self._ensure_positive_func(
 | 
				
			||||||
 | 
					                th.ones(self.action_dim) * self.param[0])
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.DIAG:
 | 
				
			||||||
 | 
					            return self._ensure_positive_func(self.param)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_SCALAR:
 | 
				
			||||||
 | 
					            cont = self.net(x)
 | 
				
			||||||
 | 
					            diag_chol = th.ones(self.action_dim) * cont * self.std_init
 | 
				
			||||||
 | 
					            return self._ensure_positive_func(diag_chol)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_HYBRID:
 | 
				
			||||||
 | 
					            cont = self.net(x)
 | 
				
			||||||
 | 
					            return self._ensure_positive_func(self.param * cont)
 | 
				
			||||||
 | 
					        elif self.par_strength == Par_Strength.CONT_DIAG:
 | 
				
			||||||
 | 
					            cont = self.net(x)
 | 
				
			||||||
 | 
					            diag_chol = cont * self.std_init
 | 
				
			||||||
 | 
					            return self._ensure_positive_func(diag_chol)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        raise Exception()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _ensure_positive_func(self, x):
 | 
				
			||||||
 | 
					        return self.enforce_positive_type.apply(x) + self.epsilon
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def string(self):
 | 
				
			||||||
 | 
					        return '<StdNet />'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test():
 | 
				
			||||||
 | 
					    mu = th.Tensor([[0.0, 0.0]])
 | 
				
			||||||
 | 
					    sigma = th.Tensor([[0.9, 0.1]])
 | 
				
			||||||
 | 
					    traj = th.Tensor([[[-1.0, -1.0], [-0.4, -0.4], [0.3, 0.3]]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    d = PCA_Distribution(2, window=3)
 | 
				
			||||||
 | 
					    d.proba_distribution(mu, sigma)
 | 
				
			||||||
 | 
					    d.print_info(traj)
 | 
				
			||||||
 | 
					    print(d.sample(traj))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    test()
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user