Fix JAX compatibility and CUDA module issues for HoReKa

- Update SLURM scripts to use correct CUDA modules (devel/cuda/12.4, intel compiler)
- Add JAX downgrade to 0.4.35 for CuDNN 9.5.1 compatibility
- Fix JAX_PLATFORMS environment variable (cuda vs gpu,cpu)
- Update README with cluster-specific JAX installation steps
- Tested successfully: Both PyTorch and JAX working on GPU with full training
This commit is contained in:
ys1087@partner.kit.edu 2025-07-22 16:36:06 +02:00
parent 336c96bb7b
commit 15750f56b2
3 changed files with 57 additions and 2 deletions

View File

@ -47,6 +47,11 @@ pip install --upgrade pip
pip install -r requirements/requirements.txt
pip install git+https://github.com/younggyoseo/mujoco_playground.git
# IMPORTANT: Downgrade JAX for HoReKa compatibility
# HoReKa has older NVIDIA drivers (CuDNN 9.5.1) that are incompatible with latest JAX
pip uninstall jax jaxlib jax-cuda12-plugin -y
pip install jax[cuda12]==0.4.35 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Test installation
python test_setup.py
```

View File

@ -12,7 +12,8 @@
# Load necessary modules
module purge
module load toolkit/CUDA/12.4
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm
# Navigate to the project directory
cd $SLURM_SUBMIT_DIR
@ -22,7 +23,8 @@ source .venv/bin/activate
# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="gpu,cpu"
export JAX_PLATFORMS="cuda"
export JAX_ENABLE_X64=True
# Ensure wandb is logged in (set WANDB_API_KEY environment variable)
# export WANDB_API_KEY=your_api_key_here

48
run_fasttd3_dev.slurm Normal file
View File

@ -0,0 +1,48 @@
#!/bin/bash
#SBATCH --job-name=fasttd3_dev_test
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=dev_accelerated
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --output=fasttd3_dev_%j.out
#SBATCH --error=fasttd3_dev_%j.err
# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm
# Navigate to the project directory
cd $SLURM_SUBMIT_DIR
# Activate the virtual environment
source .venv/bin/activate
# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
export JAX_ENABLE_X64=True
# For testing, use offline mode
export WANDB_MODE=offline
echo "Starting FastTD3 dev test at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"
# Run FastTD3 training with minimal settings for quick test
python fast_td3/train.py \
--env_name T1JoystickFlatTerrain \
--exp_name FastTD3_Dev_Test \
--seed 42 \
--total_timesteps 5000 \
--num_envs 256 \
--batch_size 1024 \
--eval_interval 2500 \
--render_interval 0 \
--project FastTD3_HoReKa_Dev
echo "Job completed at $(date)"