Compare commits

..

No commits in common. "1096dbd8480d51caca555159741038237f1164c1" and "4d6ed9b3ace01d28f52189c33c23595621d24c6d" have entirely different histories.

4 changed files with 18 additions and 169 deletions

View File

@ -160,28 +160,26 @@ class KLProjection(BaseProjection):
old_cov = old_scale_or_tril ** 2 old_cov = old_scale_or_tril ** 2
mask = cov_part > self.cov_bound mask = cov_part > self.cov_bound
proj_scale_or_tril = scale_or_tril # Start with original scale
if mask.any(): # Always compute both branches and use matrix operations to select
if self.full_cov: if self.full_cov:
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound) proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1))) is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril) valid_mask = mask & ~is_invalid
mask = mask & ~is_invalid
chol = jnp.linalg.cholesky(proj_cov) # Compute cholesky for all, let matrix ops handle selection
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril) chol = jnp.linalg.cholesky(proj_cov)
else: mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype)
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound) return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
mask = mask & ~is_invalid
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
else: else:
proj_scale_or_tril = scale_or_tril proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
return proj_scale_or_tril jnp.isinf(proj_cov.mean(axis=-1)) |
(proj_cov.min(axis=-1) < 0))
valid_mask = mask & ~is_invalid
mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype)
return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril
def _validate_inputs(self, policy_params, old_policy_params): def _validate_inputs(self, policy_params, old_policy_params):
"""Validate input parameters have correct format.""" """Validate input parameters have correct format."""

View File

@ -1,50 +0,0 @@
import jax
import jax.numpy as jnp
import time
from itpal_jax import FrobeniusProjection
def generate_params(key, batch_size, dim):
keys = jax.random.split(key, 2)
return {
"loc": jax.random.normal(keys[0], (batch_size, dim)),
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
}
def main():
# Test parameters
batch_size = 32
dim = 8
n_iterations = 1000
# Initialize projector
proj = FrobeniusProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
# Compile function
proj_fn = lambda p, op: proj.project(p, op)
proj_fn = jax.jit(proj_fn)
# Generate initial key
key = jax.random.PRNGKey(0)
# Warmup
for _ in range(10):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
# Time projections
start_time = time.time()
for _ in range(n_iterations):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
end_time = time.time()
print(f"Frobenius Projection:")
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
if __name__ == "__main__":
main()

View File

@ -1,49 +0,0 @@
import jax
import jax.numpy as jnp
import time
from itpal_jax import KLProjection
def generate_params(key, batch_size, dim):
keys = jax.random.split(key, 2)
return {
"loc": jax.random.normal(keys[0], (batch_size, dim)),
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
}
def main():
# Test parameters
batch_size = 32
dim = 8
n_iterations = 1000
# Initialize projector
proj = KLProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
# No JIT for KL projection since it uses C++ backend
proj_fn = lambda p, op: proj.project(p, op)
# Generate initial key
key = jax.random.PRNGKey(0)
# Warmup
for _ in range(10):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
# Time projections
start_time = time.time()
for _ in range(n_iterations):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
end_time = time.time()
print(f"KL Projection:")
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
if __name__ == "__main__":
main()

View File

@ -1,50 +0,0 @@
import jax
import jax.numpy as jnp
import time
from itpal_jax import WassersteinProjection
def generate_params(key, batch_size, dim):
keys = jax.random.split(key, 2)
return {
"loc": jax.random.normal(keys[0], (batch_size, dim)),
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
}
def main():
# Test parameters
batch_size = 32
dim = 8
n_iterations = 1000
# Initialize projector
proj = WassersteinProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
# Compile function
proj_fn = lambda p, op: proj.project(p, op)
proj_fn = jax.jit(proj_fn)
# Generate initial key
key = jax.random.PRNGKey(0)
# Warmup
for _ in range(10):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
# Time projections
start_time = time.time()
for _ in range(n_iterations):
key, subkey1, subkey2 = jax.random.split(key, 3)
params = generate_params(subkey1, batch_size, dim)
old_params = generate_params(subkey2, batch_size, dim)
proj_fn(params, old_params)
end_time = time.time()
print(f"Wasserstein Projection:")
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
if __name__ == "__main__":
main()