reppo/reppo_alg/jaxrl/__init__.py
2025-07-21 18:31:20 -04:00

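import jax

# Request the highest matmul precision for the whole package. On some accelerators
# JAX's default may compute float32 matmuls with lower-precision passes (e.g. bfloat16 or TF32).
jax.config.update("jax_default_matmul_precision", "highest")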