"""Micro-benchmark for ``itpal_jax.WassersteinProjection.project``."""
import time

import jax
import jax.numpy as jnp

from itpal_jax import WassersteinProjection


def generate_params(key, batch_size, dim):
    """Sample a random batch of ``loc``/``scale`` parameters.

    NOTE(review): the loc/scale naming suggests a diagonal Gaussian; confirm
    against what ``WassersteinProjection.project`` expects.

    Args:
        key: JAX PRNG key.
        batch_size: Number of rows in each returned array.
        dim: Number of columns in each returned array.

    Returns:
        Dict with ``"loc"`` (standard-normal samples) and ``"scale"``
        (softplus of standard-normal samples, hence non-negative), each of
        shape ``(batch_size, dim)``.
    """
    loc_key, scale_key = jax.random.split(key, 2)
    return {
        "loc": jax.random.normal(loc_key, (batch_size, dim)),
        "scale": jax.nn.softplus(jax.random.normal(scale_key, (batch_size, dim))),
    }


def _sample_inputs(key, n, batch_size, dim):
    """Pre-sample ``n`` ``(params, old_params)`` pairs and wait until they exist.

    The key is split in exactly the same sequence as an inline sampling loop
    would, so the generated inputs are identical to sampling per iteration.

    Returns:
        ``(key, pairs)``: the advanced PRNG key and the list of sampled pairs.
    """
    pairs = []
    for _ in range(n):
        key, subkey1, subkey2 = jax.random.split(key, 3)
        pairs.append(
            (
                generate_params(subkey1, batch_size, dim),
                generate_params(subkey2, batch_size, dim),
            )
        )
    # Sampling is dispatched asynchronously; finish it here so it cannot
    # bleed into the timed region.
    jax.block_until_ready(pairs)
    return key, pairs


def main(batch_size=32, dim=8, n_iterations=1000, n_warmup=10):
    """Time jitted ``WassersteinProjection.project`` calls and print the results.

    Args:
        batch_size: Rows per sampled parameter array.
        dim: Columns per sampled parameter array.
        n_iterations: Number of timed projections (must be >= 1).
        n_warmup: Number of untimed calls made first (triggers JIT compilation).

    Raises:
        ValueError: If ``n_iterations`` is less than 1.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    proj = WassersteinProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
    proj_fn = jax.jit(lambda p, op: proj.project(p, op))

    key = jax.random.PRNGKey(0)

    # Warmup: compile, and wait for each result so no queued work from the
    # warmup is still running when the timer starts.
    key, warmup_inputs = _sample_inputs(key, n_warmup, batch_size, dim)
    for params, old_params in warmup_inputs:
        jax.block_until_ready(proj_fn(params, old_params))

    # Inputs are generated before the timer starts, so only the projection is
    # measured (not eager PRNG/sampling overhead).
    _, timed_inputs = _sample_inputs(key, n_iterations, batch_size, dim)

    start_time = time.perf_counter()
    for params, old_params in timed_inputs:
        # JAX dispatch is asynchronous; without blocking, the loop would time
        # only the enqueue, not the computation.
        jax.block_until_ready(proj_fn(params, old_params))
    elapsed = time.perf_counter() - start_time

    print("Wasserstein Projection:")
    print(f"Average time per projection: {elapsed / n_iterations * 1000:.3f} ms")
    print(f"Total time for {n_iterations} iterations: {elapsed:.3f} s")


if __name__ == "__main__":
    main()