Compare commits

3 Commits

44eb3335ff ... de2b9a10d6

| Author | SHA1 | Date |
|---|---|---|
| | de2b9a10d6 | |
| | 9fb0014a99 | |
| | 8e991ae05b | |
```diff
@@ -65,9 +65,7 @@ pytest tests/test_projections.py
 *Note*: The test suite verifies:
 
 1. All projections run without errors and maintain basic properties (shapes, positive definiteness)
-2. KL bounds are actually (approximately) met for:
-   - KL projection (both diagonal and full covariance)
-   - Wasserstein projection (diagonal covariance only)
+2. KL bounds are actually (approximately) met for true KL projection (both diagonal and full covariance)
 3. Gradients can be computed through all projections:
    - Both through projection operation and trust region loss
    - Gradients have correct shapes and are finite
```
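For context on item 2, the bound check amounts to computing the Gaussian KL after projection and asserting it stays within the configured budget. A minimal self-contained sketch for the diagonal case (the helper and numbers here are illustrative; the real check lives in `tests/test_projections.py`, shown in the last hunk of this compare):

```python
import jax.numpy as jnp

def diag_gaussian_kl(mean, std, old_mean, old_std):
    # KL(new || old) for diagonal Gaussians, summed over the feature axis
    var, old_var = jnp.square(std), jnp.square(old_std)
    return 0.5 * jnp.sum(
        jnp.log(old_var / var) + (var + jnp.square(mean - old_mean)) / old_var - 1.0,
        axis=-1,
    )

# Illustrative bounds; the test allows a 10% margin on mean_bound + cov_bound
mean_bound, cov_bound = 0.01, 0.001
kl = diag_gaussian_kl(jnp.array([[0.1]]), jnp.array([[1.0]]),
                      jnp.array([[0.0]]), jnp.array([[1.0]]))
assert jnp.all(kl <= (mean_bound + cov_bound) * 1.1)
```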
```diff
@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict
+import jax
 
 class FrobeniusProjection(BaseProjection):
     def __init__(self, trust_region_coeff: float = 1.0, mean_bound: float = 0.01,
@@ -64,12 +65,26 @@ class FrobeniusProjection(BaseProjection):
         old_mean, old_cov = q
 
         if self.scale_prec:
-            prec_old = jnp.linalg.inv(old_cov)
-            mean_part = jnp.sum(jnp.matmul(mean - old_mean, prec_old) * (mean - old_mean), axis=-1)
-            cov_part = jnp.sum(prec_old * cov, axis=(-2, -1)) - jnp.log(jnp.linalg.det(jnp.matmul(prec_old, cov))) - mean.shape[-1]
+            # Mahalanobis distance for mean
+            diff = mean - old_mean
+            if old_cov.ndim == mean.ndim:  # diagonal case
+                mean_part = jnp.sum(jnp.square(diff / old_cov), axis=-1)
+            else:
+                solved = jax.scipy.linalg.solve_triangular(
+                    old_cov, diff[..., None], lower=True
+                )
+                mean_part = jnp.sum(jnp.square(solved.squeeze(-1)), axis=-1)
         else:
             mean_part = jnp.sum(jnp.square(mean - old_mean), axis=-1)
-            cov_part = jnp.sum(jnp.square(cov - old_cov), axis=(-2, -1))
+
+        # Frobenius norm for covariance
+        if cov.ndim == mean.ndim:  # diagonal case
+            diff = old_cov - cov
+            cov_part = jnp.sum(jnp.square(diff), axis=-1)
+        else:
+            diff = jnp.matmul(old_cov, jnp.swapaxes(old_cov, -1, -2)) - \
+                   jnp.matmul(cov, jnp.swapaxes(cov, -1, -2))
+            cov_part = jnp.sum(jnp.square(diff), axis=(-2, -1))
+
         return mean_part, cov_part
 
```
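The triangular solve in the `+` lines assumes that, in the full-covariance case, `old_cov` carries a lower Cholesky factor $L$ of the old covariance: for $\Sigma = LL^\top$, the Mahalanobis term satisfies $d^\top \Sigma^{-1} d = \lVert L^{-1} d \rVert^2$, which avoids the explicit `jnp.linalg.inv` of the removed lines. A small self-contained check of that identity (all names local to the sketch):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (4, 4))
cov = A @ A.T + 4.0 * jnp.eye(4)      # SPD covariance
chol = jnp.linalg.cholesky(cov)       # lower-triangular factor L
diff = jax.random.normal(jax.random.PRNGKey(1), (4,))

# Route used in the diff: ||L^{-1} diff||^2 via a triangular solve
solved = jax.scipy.linalg.solve_triangular(chol, diff[..., None], lower=True)
maha_solve = jnp.sum(jnp.square(solved.squeeze(-1)))

# Route used by the removed code: diff^T Sigma^{-1} diff via an explicit inverse
maha_inv = diff @ jnp.linalg.inv(cov) @ diff

assert jnp.allclose(maha_solve, maha_inv, rtol=1e-4)
```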
```diff
@@ -83,22 +98,20 @@ class FrobeniusProjection(BaseProjection):
 
     def _cov_projection(self, cov: jnp.ndarray, old_cov: jnp.ndarray, 
                        cov_part: jnp.ndarray) -> jnp.ndarray:
-        batch_shape = cov.shape[:-2]
+        batch_shape = cov.shape[:-2] if cov.ndim > 2 else cov.shape[:-1]
         cov_mask = cov_part > self.cov_bound
 
         eta = jnp.ones(batch_shape, dtype=cov.dtype)
         eta = jnp.where(cov_mask, 
                         jnp.sqrt(cov_part / self.cov_bound) - 1.,
                         eta)
         eta = jnp.maximum(-eta, eta)
 
         if self.full_cov:
-            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
+            new_cov = (cov + jnp.einsum('...,...ij->...ij', eta, old_cov)) / \
+                      (1. + eta + 1e-16)[..., None, None]
+            proj_cov = jnp.where(cov_mask[..., None, None], new_cov, cov)
+            return jnp.linalg.cholesky(proj_cov)
         else:
-            # For diagonal case, simple broadcasting
             new_cov = (cov + eta[..., None] * old_cov) / (1. + eta + 1e-16)[..., None]
-
-        proj_cov = jnp.where(cov_mask[..., None] if not self.full_cov else cov_mask[..., None, None], 
-                             new_cov, cov)
-
-        return proj_cov
+            return jnp.where(cov_mask[..., None], jnp.sqrt(new_cov), cov)
```
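A step worth spelling out, since the code only hints at it: whenever `cov_mask` fires, the interpolation lands exactly on the bound. Writing $d$ for the squared Frobenius distance `cov_part` and $\epsilon_\Sigma$ for `self.cov_bound`:

$$\Sigma_{\text{proj}} - \Sigma_{\text{old}} = \frac{\Sigma + \eta\,\Sigma_{\text{old}}}{1 + \eta} - \Sigma_{\text{old}} = \frac{\Sigma - \Sigma_{\text{old}}}{1 + \eta},$$

so with $\eta = \sqrt{d/\epsilon_\Sigma} - 1$ the projected distance is $d/(1+\eta)^2 = \epsilon_\Sigma$: the constraint is met with equality, and the `1e-16` only guards against division by zero.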
```diff
@@ -13,6 +13,24 @@ from .exception_projection import makeExceptionProjection
 
 MAX_EVAL = 1000
 
+# Cache for projection operators
+_diag_proj_op = None
+_full_proj_op = None
+
+def _get_diag_proj_op(batch_shape, dim):
+    global _diag_proj_op
+    if _diag_proj_op is None:
+        _diag_proj_op = cpp_projection.BatchedDiagCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _diag_proj_op
+
+def _get_full_proj_op(batch_shape, dim):
+    global _full_proj_op
+    if _full_proj_op is None:
+        _full_proj_op = cpp_projection.BatchedCovOnlyProjection(
+            batch_shape, dim, max_eval=MAX_EVAL)
+    return _full_proj_op
+
 class KLProjection(BaseProjection):
     """KL divergence-based projection for Gaussian policies.
     
```
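Note the cache above stores a single operator per kind, so it implicitly assumes `batch_shape` and `dim` never change between calls. A minimal sketch of the shape-keyed variant of the same memoization pattern (`make_op` is a hypothetical stand-in for the `cpp_projection` constructors):

```python
# Shape-keyed operator cache: a new (batch_shape, dim) combination builds
# a fresh operator, repeated combinations reuse the cached one.
_op_cache = {}

def make_op(batch_shape, dim):
    # Hypothetical factory; the diff calls cpp_projection.Batched*Projection here.
    return ("op", batch_shape, dim)

def get_op(batch_shape, dim):
    key = (batch_shape, dim)
    if key not in _op_cache:
        _op_cache[key] = make_op(batch_shape, dim)
    return _op_cache[key]
```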
```diff
@@ -166,6 +184,49 @@ class KLProjection(BaseProjection):
             if key not in policy_params or key not in old_policy_params:
                 raise KeyError(f"Missing required key '{key}' in policy parameters")
 
+@partial(jax.custom_vjp, nondiff_argnums=(2,))
+def project_diag_covariance(cov, old_cov, eps_cov):
+    """JAX wrapper for C++ diagonal covariance projection"""
+    batch_shape = cov.shape[0]
+    dim = cov.shape[-1]
+
+    cov_np = np.asarray(cov)
+    old_cov_np = np.asarray(old_cov)
+    eps = eps_cov * np.ones(batch_shape, dtype=old_cov_np.dtype)
+
+    p_op = _get_diag_proj_op(batch_shape, dim)
+
+    try:
+        proj_cov = p_op.forward(eps, old_cov_np, cov_np)
+    except:
+        proj_cov = cov_np  # Return input on failure
+
+    return jnp.array(proj_cov)
+
+def project_diag_covariance_fwd(cov, old_cov, eps_cov):
+    y = project_diag_covariance(cov, old_cov, eps_cov)
+    return y, (cov, old_cov)
+
+def project_diag_covariance_bwd(eps_cov, res, g):
+    cov, old_cov = res
+
+    # Convert to numpy for C++ backward pass
+    g_np = np.asarray(g)
+    batch_shape = g_np.shape[0]
+    dim = g_np.shape[-1]
+
+    # Get C++ projection operator
+    p_op = _get_diag_proj_op(batch_shape, dim)
+
+    # Run C++ backward pass
+    grad_cov = p_op.backward(g_np)
+
+    # Convert back to JAX array
+    return jnp.array(grad_cov), None
+
+# Register VJP rule for diagonal covariance projection
+project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
+
 @partial(jax.custom_vjp, nondiff_argnums=(3,))
 def project_full_covariance(cov, chol, old_chol, eps_cov):
     """JAX wrapper for C++ full covariance projection"""
```
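The `custom_vjp` wiring above follows the standard JAX recipe: the non-differentiable `eps_cov` is declared via `nondiff_argnums`, the forward rule saves residuals, and the backward rule returns one cotangent per differentiable argument (`None` for `old_cov`). A self-contained toy with the same structure, using a pure-JAX stand-in instead of the C++ projection:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def clip_towards(x, x_old, eps):
    # Toy stand-in for the projection: keep x within eps of x_old
    return x_old + jnp.clip(x - x_old, -eps, eps)

def clip_towards_fwd(x, x_old, eps):
    y = clip_towards(x, x_old, eps)
    return y, (x, x_old)           # residuals for the backward pass

def clip_towards_bwd(eps, res, g):
    x, x_old = res
    inside = jnp.abs(x - x_old) <= eps
    # One cotangent per differentiable argument, in order (x, x_old)
    return g * inside, g * (~inside)

clip_towards.defvjp(clip_towards_fwd, clip_towards_bwd)

grad = jax.grad(lambda x: jnp.sum(clip_towards(x, jnp.zeros(3), 0.5)))(
    jnp.array([0.2, 0.9, -1.0]))   # -> [1., 0., 0.]
```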
```diff
@@ -179,7 +240,7 @@ def project_full_covariance(cov, chol, old_chol, eps_cov):
         eps = eps_cov * np.ones(batch_shape)
 
         # Create C++ projection operator directly
-        p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+        p_op = _get_full_proj_op(batch_shape, dim)
 
         # Run C++ projection
         proj_cov = p_op.forward(eps, old_chol_np, chol_np, cov_np)
@@ -203,7 +264,7 @@ def project_full_covariance_bwd(eps_cov, res, g):
     dim = g_np.shape[-1]
 
     # Get C++ projection operator
-    p_op = cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
+    p_op = _get_full_proj_op(batch_shape, dim)
 
     # Run C++ backward pass
     grad_cov = p_op.backward(g_np)
@@ -214,48 +275,5 @@ def project_full_covariance_bwd(eps_cov, res, g):
 # Register VJP rule for full covariance projection
 project_full_covariance.defvjp(project_full_covariance_fwd, project_full_covariance_bwd)
 
-@partial(jax.custom_vjp, nondiff_argnums=(2,))
-def project_diag_covariance(cov, old_cov, eps_cov):
-    """JAX wrapper for C++ diagonal covariance projection"""
-    # Convert JAX arrays to numpy for C++ function
-    cov_np = np.asarray(cov)
-    old_cov_np = np.asarray(old_cov)
-    batch_shape = cov_np.shape[0]
-    dim = cov_np.shape[-1]
-    eps = eps_cov * np.ones(batch_shape)
-
-    # Create C++ projection operator directly
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-
-    # Run C++ projection
-    proj_cov = p_op.forward(eps, old_cov_np, cov_np)
-
-    # Convert back to JAX array
-    return jnp.array(proj_cov)
-
-def project_diag_covariance_fwd(cov, old_cov, eps_cov):
-    y = project_diag_covariance(cov, old_cov, eps_cov)
-    return y, (cov, old_cov)
-
-def project_diag_covariance_bwd(eps_cov, res, g):
-    cov, old_cov = res
-
-    # Convert to numpy for C++ backward pass
-    g_np = np.asarray(g)
-    batch_shape = g_np.shape[0]
-    dim = g_np.shape[-1]
-
-    # Get C++ projection operator
-    p_op = cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=MAX_EVAL)
-
-    # Run C++ backward pass
-    grad_cov = p_op.backward(g_np)
-
-    # Convert back to JAX array
-    return jnp.array(grad_cov), None
-
-# Register VJP rule for diagonal covariance projection
-project_diag_covariance.defvjp(project_diag_covariance_fwd, project_diag_covariance_bwd)
-
 if not cpp_projection_available:
     KLProjection = makeExceptionProjection("ITPAL (C++ library) is not available. Please install the C++ library to use this projection.")
```
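The fallback at the bottom swaps the class for one that fails loudly when the C++ backend is missing. The helper itself is not shown in this compare; a sketch of the assumed shape of `makeExceptionProjection`:

```python
# Hypothetical reconstruction of the helper's shape; the real one lives in
# exception_projection.py, which this compare does not show.
def makeExceptionProjection(message):
    class _ExceptionProjection:
        def __init__(self, *args, **kwargs):
            raise ImportError(message)
    return _ExceptionProjection
```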
```diff
@@ -1,6 +1,7 @@
 import jax.numpy as jnp
 from .base_projection import BaseProjection
 from typing import Dict, Tuple
+import jax
 
 def scale_tril_to_sqrt(scale_tril: jnp.ndarray) -> jnp.ndarray:
     """
@@ -22,7 +23,8 @@ class WassersteinProjection(BaseProjection):
 
     def project(self, policy_params: Dict[str, jnp.ndarray], 
                 old_policy_params: Dict[str, jnp.ndarray]) -> Dict[str, jnp.ndarray]:
-        assert not self.full_cov, "Wasserstein projection only supports diagonal covariance"
+        if self.full_cov:
+            print("Warning: Wasserstein projection with full covariance is wip, we recommend using diagonal covariance instead.")
 
         mean = policy_params["loc"]  # shape: (batch_size, dim)
         old_mean = old_policy_params["loc"]
@@ -73,27 +75,47 @@ class WassersteinProjection(BaseProjection):
 
     def _scale_projection(self, scale: jnp.ndarray, old_scale: jnp.ndarray, 
                          scale_part: jnp.ndarray) -> jnp.ndarray:
-        """Project scale parameters (standard deviations for diagonal case)"""
-        diff = scale - old_scale
-        norm = jnp.sqrt(scale_part)
-        
-        if scale.ndim == 2:  # Batched scale
-            norm = norm[..., None]
-        
-        return jnp.where(norm > self.cov_bound,
-                        old_scale + diff * self.cov_bound / norm,
-                        scale)
+        """Project scale parameters using multiplicative update.
+
+        Args:
+            scale: Current scale/sqrt of covariance
+            old_scale: Previous scale/sqrt of covariance
+            scale_part: W2 distance between scales
+
+        Returns:
+            Projected scale that satisfies the trust region constraint
+        """
+        # Check if projection needed
+        cov_mask = scale_part > self.cov_bound
+
+        # Compute eta (multiplier for the update)
+        batch_shape = scale.shape[:-2] if scale.ndim > 2 else scale.shape[:-1]
+        eta = jnp.ones(batch_shape, dtype=scale.dtype)
+        eta = jnp.where(cov_mask,
+                        jnp.sqrt(scale_part / self.cov_bound) - 1.,
+                        eta)
+        eta = jnp.maximum(-eta, eta)
+
+        # Multiplicative update with correct broadcasting
+        if scale.ndim > 2:  # Full covariance case
+            new_scale = (scale + jnp.einsum('...,...ij->...ij', eta, old_scale)) / \
+                       (1. + eta + 1e-16)[..., None, None]
+            mask = cov_mask[..., None, None]
+        else:  # Diagonal case
+            new_scale = (scale + eta[..., None] * old_scale) / \
+                       (1. + eta + 1e-16)[..., None]
+            mask = cov_mask[..., None]
+
+        return jnp.where(mask, new_scale, scale)
+
     def _gaussian_wasserstein(self, p, q):
         mean, scale = p
         mean_other, scale_other = q
 
-        # Keep batch dimension by only summing over feature dimension
-        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)  # -> (batch_size,)
+        # Euclidean distance for mean part (we're in diagonal case)
+        mean_part = jnp.sum(jnp.square(mean - mean_other), axis=-1)
 
-        if scale.ndim == mean.ndim:  # Batched scale
-            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
-        else:  # Non-contextual scale (single scale for all batches)
-            cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale)
+        # Standard W2 objective for covariance (diagonal case)
+        cov_part = jnp.sum(scale_other**2 + scale**2 - 2 * scale_other * scale, axis=-1)
 
         return mean_part, cov_part
```
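For reference, `_gaussian_wasserstein` now computes the squared 2-Wasserstein distance between diagonal Gaussians; the covariance summand `scale_other**2 + scale**2 - 2 * scale_other * scale` is just $(\sigma_i - \sigma_i')^2$ expanded, so the two parts are

$$W_2^2(p, q) = \lVert \mu - \mu' \rVert_2^2 + \sum_i \left(\sigma_i - \sigma_i'\right)^2 .$$

`_scale_projection` then reuses the same $\eta$-interpolation as the Frobenius case: since the projected scale difference is $(\text{scale} - \text{old\_scale})/(1+\eta)$, the covariance term is driven back exactly to `self.cov_bound` when it is exceeded.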
```diff
@@ -94,8 +94,8 @@ def test_diagonal_projection(ProjectionClass, needs_cpp, gaussian_params):
     assert jnp.all(jnp.isfinite(proj_params["scale"]))
     assert jnp.all(proj_params["scale"] > 0)
 
-    # Only check KL bounds for KL projection (and W2, which should approx hold as well)
-    if ProjectionClass in [KLProjection, WassersteinProjection]:
+    # Only check KL bounds for KL projection
+    if ProjectionClass in [KLProjection]:
         kl = compute_gaussian_kl(proj_params, gaussian_params["old_params"])
         max_kl = (mean_bound + cov_bound) * 1.1  # Allow 10% margin
 
```