import jax
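# Make matmuls and convolutions on float32 inputs default to full float32
# precision rather than faster reduced-precision modes (e.g. bfloat16 passes
# on TPU, TF32 on recent NVIDIA GPUs).
jax.config.update("jax_default_matmul_precision", "highest")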
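The global flag is one of three ways to control matmul precision in JAX. Below is a minimal sketch, assuming a recent JAX release, that contrasts the global default with a per-call `precision=` argument and with the scoped `jax.default_matmul_precision` context manager. The array shapes, PRNG seed, and variable names are arbitrary illustrations. On CPU the precision setting usually has little or no effect, so the three results will typically match there; the differences show up on TPUs and on GPUs that support TF32.

```python
import jax
import jax.numpy as jnp

# Global default: every matmul/conv on float32 inputs uses full float32
# precision unless a call or scope overrides it.
jax.config.update("jax_default_matmul_precision", "highest")

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 128), dtype=jnp.float32)
b = jax.random.normal(key_b, (128, 128), dtype=jnp.float32)

# Uses the global "highest" setting.
accurate = jnp.matmul(a, b)

# Per-call override: the `precision` argument takes precedence over the
# global default for this one operation only.
fast = jnp.matmul(a, b, precision=jax.lax.Precision.DEFAULT)

# Scoped override: the context manager changes the default only inside
# the `with` block; the global setting is restored afterwards.
with jax.default_matmul_precision("bfloat16"):
    scoped = jnp.matmul(a, b)
```

The global flag is convenient for numerical tests or scientific code where accuracy matters everywhere; the per-call and scoped forms let you keep the fast default and opt into full precision only where it is needed (or vice versa).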