- Update SLURM scripts to use correct CUDA modules (devel/cuda/12.4, intel compiler)
- Add JAX downgrade to 0.4.35 for CuDNN 9.5.1 compatibility
- Fix JAX_PLATFORMS environment variable (cuda vs gpu,cpu)
- Update README with cluster-specific JAX installation steps
- Tested successfully: Both PyTorch and JAX working on GPU with full training