diff --git a/perf_tests/perf_test_frobenius.py b/perf_tests/perf_test_frobenius.py new file mode 100644 index 0000000..0a75726 --- /dev/null +++ b/perf_tests/perf_test_frobenius.py @@ -0,0 +1,50 @@ +import jax +import jax.numpy as jnp +import time +from itpal_jax import FrobeniusProjection + +def generate_params(key, batch_size, dim): + keys = jax.random.split(key, 2) + return { + "loc": jax.random.normal(keys[0], (batch_size, dim)), + "scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim))) + } + +def main(): + # Test parameters + batch_size = 32 + dim = 8 + n_iterations = 1000 + + # Initialize projector + proj = FrobeniusProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True) + + # Compile function + proj_fn = lambda p, op: proj.project(p, op) + proj_fn = jax.jit(proj_fn) + + # Generate initial key + key = jax.random.PRNGKey(0) + + # Warmup + for _ in range(10): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + + # Time projections + start_time = time.time() + for _ in range(n_iterations): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + end_time = time.time() + + print(f"Frobenius Projection:") + print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms") + print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/perf_tests/perf_test_kl.py b/perf_tests/perf_test_kl.py new file mode 100644 index 0000000..66d05a5 --- /dev/null +++ b/perf_tests/perf_test_kl.py @@ -0,0 +1,49 @@ +import jax +import jax.numpy as jnp +import time +from itpal_jax import KLProjection + +def generate_params(key, batch_size, dim): + keys = jax.random.split(key, 2) + return { + "loc": jax.random.normal(keys[0], (batch_size, dim)), + "scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim))) + } + +def main(): + # Test parameters + batch_size = 32 + dim = 8 + n_iterations = 1000 + + # Initialize projector + proj = KLProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True) + + # No JIT for KL projection since it uses C++ backend + proj_fn = lambda p, op: proj.project(p, op) + + # Generate initial key + key = jax.random.PRNGKey(0) + + # Warmup + for _ in range(10): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + + # Time projections + start_time = time.time() + for _ in range(n_iterations): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + end_time = time.time() + + print(f"KL Projection:") + print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms") + print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/perf_tests/perf_test_wasserstein.py b/perf_tests/perf_test_wasserstein.py new file mode 100644 index 0000000..16d00a6 --- /dev/null +++ b/perf_tests/perf_test_wasserstein.py @@ -0,0 +1,50 @@ +import jax +import jax.numpy as jnp +import time +from itpal_jax import WassersteinProjection + +def generate_params(key, batch_size, dim): + keys = jax.random.split(key, 2) + return { + "loc": jax.random.normal(keys[0], (batch_size, dim)), + "scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim))) + } + +def main(): + # Test parameters + batch_size = 32 + dim = 8 + n_iterations = 1000 + + # Initialize projector + proj = WassersteinProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True) + + # Compile function + proj_fn = lambda p, op: proj.project(p, op) + proj_fn = jax.jit(proj_fn) + + # Generate initial key + key = jax.random.PRNGKey(0) + + # Warmup + for _ in range(10): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + + # Time projections + start_time = time.time() + for _ in range(n_iterations): + key, subkey1, subkey2 = jax.random.split(key, 3) + params = generate_params(subkey1, batch_size, dim) + old_params = generate_params(subkey2, batch_size, dim) + proj_fn(params, old_params) + end_time = time.time() + + print(f"Wasserstein Projection:") + print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms") + print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s") + +if __name__ == "__main__": + main() \ No newline at end of file