Perf tests
This commit is contained in:
parent
404320c5cc
commit
1096dbd848
50
perf_tests/perf_test_frobenius.py
Normal file
50
perf_tests/perf_test_frobenius.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import time
|
||||||
|
from itpal_jax import FrobeniusProjection
|
||||||
|
|
||||||
|
def generate_params(key, batch_size, dim):
|
||||||
|
keys = jax.random.split(key, 2)
|
||||||
|
return {
|
||||||
|
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
||||||
|
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Test parameters
|
||||||
|
batch_size = 32
|
||||||
|
dim = 8
|
||||||
|
n_iterations = 1000
|
||||||
|
|
||||||
|
# Initialize projector
|
||||||
|
proj = FrobeniusProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
||||||
|
|
||||||
|
# Compile function
|
||||||
|
proj_fn = lambda p, op: proj.project(p, op)
|
||||||
|
proj_fn = jax.jit(proj_fn)
|
||||||
|
|
||||||
|
# Generate initial key
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(10):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
|
||||||
|
# Time projections
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(n_iterations):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
print(f"Frobenius Projection:")
|
||||||
|
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
||||||
|
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
49
perf_tests/perf_test_kl.py
Normal file
49
perf_tests/perf_test_kl.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import time
|
||||||
|
from itpal_jax import KLProjection
|
||||||
|
|
||||||
|
def generate_params(key, batch_size, dim):
|
||||||
|
keys = jax.random.split(key, 2)
|
||||||
|
return {
|
||||||
|
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
||||||
|
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Test parameters
|
||||||
|
batch_size = 32
|
||||||
|
dim = 8
|
||||||
|
n_iterations = 1000
|
||||||
|
|
||||||
|
# Initialize projector
|
||||||
|
proj = KLProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
||||||
|
|
||||||
|
# No JIT for KL projection since it uses C++ backend
|
||||||
|
proj_fn = lambda p, op: proj.project(p, op)
|
||||||
|
|
||||||
|
# Generate initial key
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(10):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
|
||||||
|
# Time projections
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(n_iterations):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
print(f"KL Projection:")
|
||||||
|
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
||||||
|
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
50
perf_tests/perf_test_wasserstein.py
Normal file
50
perf_tests/perf_test_wasserstein.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import time
|
||||||
|
from itpal_jax import WassersteinProjection
|
||||||
|
|
||||||
|
def generate_params(key, batch_size, dim):
|
||||||
|
keys = jax.random.split(key, 2)
|
||||||
|
return {
|
||||||
|
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
||||||
|
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Test parameters
|
||||||
|
batch_size = 32
|
||||||
|
dim = 8
|
||||||
|
n_iterations = 1000
|
||||||
|
|
||||||
|
# Initialize projector
|
||||||
|
proj = WassersteinProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
||||||
|
|
||||||
|
# Compile function
|
||||||
|
proj_fn = lambda p, op: proj.project(p, op)
|
||||||
|
proj_fn = jax.jit(proj_fn)
|
||||||
|
|
||||||
|
# Generate initial key
|
||||||
|
key = jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(10):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
|
||||||
|
# Time projections
|
||||||
|
start_time = time.time()
|
||||||
|
for _ in range(n_iterations):
|
||||||
|
key, subkey1, subkey2 = jax.random.split(key, 3)
|
||||||
|
params = generate_params(subkey1, batch_size, dim)
|
||||||
|
old_params = generate_params(subkey2, batch_size, dim)
|
||||||
|
proj_fn(params, old_params)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
print(f"Wasserstein Projection:")
|
||||||
|
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
||||||
|
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user