itpal_jax/perf_tests/perf_test_kl.py

49 lines
1.5 KiB
Python
Raw Normal View History

2025-01-07 18:24:41 +01:00
import jax
import jax.numpy as jnp
import time
from itpal_jax import KLProjection
def generate_params(key, batch_size, dim):
    """Draw a random ``{"loc", "scale"}`` parameter dict.

    Args:
        key: JAX PRNG key; it is split into one subkey per returned entry.
        batch_size: Leading dimension of each returned array.
        dim: Trailing dimension of each returned array.

    Returns:
        dict with ``"loc"`` drawn from a standard normal, and ``"scale"``
        obtained by passing an independent standard-normal draw through
        softplus (which keeps its entries positive). Both arrays have shape
        ``(batch_size, dim)``.
    """
    shape = (batch_size, dim)
    loc_key, scale_key = jax.random.split(key, 2)
    loc = jax.random.normal(loc_key, shape)
    raw_scale = jax.random.normal(scale_key, shape)
    return {"loc": loc, "scale": jax.nn.softplus(raw_scale)}
def _make_inputs(key, count, batch_size, dim):
    """Return the advanced key and ``count`` (params, old_params) pairs.

    Splits ``key`` three ways per pair, in the same order the original
    benchmark loop did, so the generated inputs are unchanged.
    """
    pairs = []
    for _ in range(count):
        key, params_key, old_key = jax.random.split(key, 3)
        pairs.append(
            (
                generate_params(params_key, batch_size, dim),
                generate_params(old_key, batch_size, dim),
            )
        )
    return key, pairs


def main(batch_size=32, dim=8, n_iterations=1000, n_warmup=10):
    """Benchmark ``KLProjection.project`` on random parameter dicts.

    Runs ``n_warmup`` untimed projections, then times ``n_iterations``
    projections and prints the average per-call latency and the total time.

    All inputs are generated and materialized before the clock starts, so the
    reported numbers cover only the projection calls, not random-number
    generation. Note that this holds all ``n_iterations`` input pairs in
    memory at once.

    Args:
        batch_size: Leading dimension of each generated parameter array.
        dim: Trailing dimension of each generated parameter array.
        n_iterations: Number of timed projection calls; must be positive.
        n_warmup: Number of untimed projection calls run first.

    Raises:
        ValueError: If ``n_iterations`` is not positive (the average would
            otherwise divide by zero).
    """
    if n_iterations <= 0:
        raise ValueError(f"n_iterations must be positive, got {n_iterations}")

    # Initialize projector
    proj = KLProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
    # No JIT for KL projection since it uses C++ backend; call the bound
    # method directly rather than wrapping it in a lambda.
    proj_fn = proj.project

    key = jax.random.PRNGKey(0)
    key, warmup_inputs = _make_inputs(key, n_warmup, batch_size, dim)
    _, timed_inputs = _make_inputs(key, n_iterations, batch_size, dim)
    # JAX dispatches asynchronously: force the inputs to be computed now so
    # their generation is not billed to the timed loop.
    jax.block_until_ready(timed_inputs)

    # Warmup
    for params, old_params in warmup_inputs:
        jax.block_until_ready(proj_fn(params, old_params))

    # Time projections. perf_counter is monotonic and high-resolution, unlike
    # time.time(). Blocking on each result guards against any JAX arrays in
    # the output still being queued when the clock stops.
    start_time = time.perf_counter()
    for params, old_params in timed_inputs:
        jax.block_until_ready(proj_fn(params, old_params))
    elapsed = time.perf_counter() - start_time

    print("KL Projection:")
    print(f"Average time per projection: {elapsed / n_iterations * 1000:.3f} ms")
    print(f"Total time for {n_iterations} iterations: {elapsed:.3f} s")


if __name__ == "__main__":
    main()