Compare commits
No commits in common. "1096dbd8480d51caca555159741038237f1164c1" and "4d6ed9b3ace01d28f52189c33c23595621d24c6d" have entirely different histories.
1096dbd848
...
4d6ed9b3ac
@ -160,28 +160,26 @@ class KLProjection(BaseProjection):
|
|||||||
old_cov = old_scale_or_tril ** 2
|
old_cov = old_scale_or_tril ** 2
|
||||||
|
|
||||||
mask = cov_part > self.cov_bound
|
mask = cov_part > self.cov_bound
|
||||||
proj_scale_or_tril = scale_or_tril # Start with original scale
|
|
||||||
|
|
||||||
if mask.any():
|
# Always compute both branches and use matrix operations to select
|
||||||
if self.full_cov:
|
if self.full_cov:
|
||||||
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
proj_cov = project_full_covariance(cov, scale_or_tril, old_scale_or_tril, self.cov_bound)
|
||||||
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
is_invalid = jnp.isnan(proj_cov.mean(axis=(-2, -1)))
|
||||||
proj_scale_or_tril = jnp.where(is_invalid[..., None, None], old_scale_or_tril, scale_or_tril)
|
valid_mask = mask & ~is_invalid
|
||||||
mask = mask & ~is_invalid
|
|
||||||
chol = jnp.linalg.cholesky(proj_cov)
|
# Compute cholesky for all, let matrix ops handle selection
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None, None], chol, proj_scale_or_tril)
|
chol = jnp.linalg.cholesky(proj_cov)
|
||||||
else:
|
mask_matrix = valid_mask[..., None, None].astype(scale_or_tril.dtype)
|
||||||
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
return mask_matrix * chol + (1 - mask_matrix) * scale_or_tril
|
||||||
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
|
||||||
jnp.isinf(proj_cov.mean(axis=-1)) |
|
|
||||||
(proj_cov.min(axis=-1) < 0))
|
|
||||||
proj_scale_or_tril = jnp.where(is_invalid[..., None], old_scale_or_tril, scale_or_tril)
|
|
||||||
mask = mask & ~is_invalid
|
|
||||||
proj_scale_or_tril = jnp.where(mask[..., None], jnp.sqrt(proj_cov), scale_or_tril)
|
|
||||||
else:
|
else:
|
||||||
proj_scale_or_tril = scale_or_tril
|
proj_cov = project_diag_covariance(cov, old_cov, self.cov_bound)
|
||||||
|
is_invalid = (jnp.isnan(proj_cov.mean(axis=-1)) |
|
||||||
return proj_scale_or_tril
|
jnp.isinf(proj_cov.mean(axis=-1)) |
|
||||||
|
(proj_cov.min(axis=-1) < 0))
|
||||||
|
valid_mask = mask & ~is_invalid
|
||||||
|
|
||||||
|
mask_matrix = valid_mask[..., None].astype(scale_or_tril.dtype)
|
||||||
|
return mask_matrix * jnp.sqrt(proj_cov) + (1 - mask_matrix) * scale_or_tril
|
||||||
|
|
||||||
def _validate_inputs(self, policy_params, old_policy_params):
|
def _validate_inputs(self, policy_params, old_policy_params):
|
||||||
"""Validate input parameters have correct format."""
|
"""Validate input parameters have correct format."""
|
||||||
|
@ -1,50 +0,0 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import time
|
|
||||||
from itpal_jax import FrobeniusProjection
|
|
||||||
|
|
||||||
def generate_params(key, batch_size, dim):
|
|
||||||
keys = jax.random.split(key, 2)
|
|
||||||
return {
|
|
||||||
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
|
||||||
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
|
||||||
}
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Test parameters
|
|
||||||
batch_size = 32
|
|
||||||
dim = 8
|
|
||||||
n_iterations = 1000
|
|
||||||
|
|
||||||
# Initialize projector
|
|
||||||
proj = FrobeniusProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
|
||||||
|
|
||||||
# Compile function
|
|
||||||
proj_fn = lambda p, op: proj.project(p, op)
|
|
||||||
proj_fn = jax.jit(proj_fn)
|
|
||||||
|
|
||||||
# Generate initial key
|
|
||||||
key = jax.random.PRNGKey(0)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
|
|
||||||
# Time projections
|
|
||||||
start_time = time.time()
|
|
||||||
for _ in range(n_iterations):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
print(f"Frobenius Projection:")
|
|
||||||
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
|
||||||
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,49 +0,0 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import time
|
|
||||||
from itpal_jax import KLProjection
|
|
||||||
|
|
||||||
def generate_params(key, batch_size, dim):
|
|
||||||
keys = jax.random.split(key, 2)
|
|
||||||
return {
|
|
||||||
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
|
||||||
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
|
||||||
}
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Test parameters
|
|
||||||
batch_size = 32
|
|
||||||
dim = 8
|
|
||||||
n_iterations = 1000
|
|
||||||
|
|
||||||
# Initialize projector
|
|
||||||
proj = KLProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
|
||||||
|
|
||||||
# No JIT for KL projection since it uses C++ backend
|
|
||||||
proj_fn = lambda p, op: proj.project(p, op)
|
|
||||||
|
|
||||||
# Generate initial key
|
|
||||||
key = jax.random.PRNGKey(0)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
|
|
||||||
# Time projections
|
|
||||||
start_time = time.time()
|
|
||||||
for _ in range(n_iterations):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
print(f"KL Projection:")
|
|
||||||
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
|
||||||
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -1,50 +0,0 @@
|
|||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import time
|
|
||||||
from itpal_jax import WassersteinProjection
|
|
||||||
|
|
||||||
def generate_params(key, batch_size, dim):
|
|
||||||
keys = jax.random.split(key, 2)
|
|
||||||
return {
|
|
||||||
"loc": jax.random.normal(keys[0], (batch_size, dim)),
|
|
||||||
"scale": jax.nn.softplus(jax.random.normal(keys[1], (batch_size, dim)))
|
|
||||||
}
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Test parameters
|
|
||||||
batch_size = 32
|
|
||||||
dim = 8
|
|
||||||
n_iterations = 1000
|
|
||||||
|
|
||||||
# Initialize projector
|
|
||||||
proj = WassersteinProjection(mean_bound=0.1, cov_bound=0.1, contextual_std=True)
|
|
||||||
|
|
||||||
# Compile function
|
|
||||||
proj_fn = lambda p, op: proj.project(p, op)
|
|
||||||
proj_fn = jax.jit(proj_fn)
|
|
||||||
|
|
||||||
# Generate initial key
|
|
||||||
key = jax.random.PRNGKey(0)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(10):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
|
|
||||||
# Time projections
|
|
||||||
start_time = time.time()
|
|
||||||
for _ in range(n_iterations):
|
|
||||||
key, subkey1, subkey2 = jax.random.split(key, 3)
|
|
||||||
params = generate_params(subkey1, batch_size, dim)
|
|
||||||
old_params = generate_params(subkey2, batch_size, dim)
|
|
||||||
proj_fn(params, old_params)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
print(f"Wasserstein Projection:")
|
|
||||||
print(f"Average time per projection: {(end_time - start_time) / n_iterations * 1000:.3f} ms")
|
|
||||||
print(f"Total time for {n_iterations} iterations: {end_time - start_time:.3f} s")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
Loading…
Reference in New Issue
Block a user