Fix JAX compatibility and CUDA module issues for HoReKa
- Update SLURM scripts to use the correct CUDA modules (devel/cuda/12.4 plus the Intel compiler). - Downgrade JAX to 0.4.35 for compatibility with CuDNN 9.5.1. - Fix the JAX_PLATFORMS environment variable ("cuda" instead of "gpu,cpu"). - Update README with cluster-specific JAX installation steps. - Tested successfully: both PyTorch and JAX work on the GPU with full training.
This commit is contained in:
parent
336c96bb7b
commit
15750f56b2
@ -47,6 +47,11 @@ pip install --upgrade pip
|
||||
pip install -r requirements/requirements.txt
|
||||
pip install git+https://github.com/younggyoseo/mujoco_playground.git
|
||||
|
||||
# IMPORTANT: Downgrade JAX for HoReKa compatibility
|
||||
# HoReKa ships an older CuDNN (9.5.1) that is incompatible with the latest JAX releases
|
||||
pip uninstall jax jaxlib jax-cuda12-plugin -y
|
||||
pip install jax[cuda12]==0.4.35 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Test installation
|
||||
python test_setup.py
|
||||
```
|
||||
|
@ -12,7 +12,8 @@
|
||||
|
||||
# Load necessary modules
|
||||
module purge
|
||||
module load toolkit/CUDA/12.4
|
||||
module load devel/cuda/12.4
|
||||
module load compiler/intel/2025.1_llvm
|
||||
|
||||
# Navigate to the project directory
|
||||
cd $SLURM_SUBMIT_DIR
|
||||
@ -22,7 +23,8 @@ source .venv/bin/activate
|
||||
|
||||
# Set environment variables for proper GPU usage
|
||||
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
|
||||
export JAX_PLATFORMS="gpu,cpu"
|
||||
export JAX_PLATFORMS="cuda"
|
||||
export JAX_ENABLE_X64=True
|
||||
|
||||
# Ensure wandb is logged in (set WANDB_API_KEY environment variable)
|
||||
# export WANDB_API_KEY=your_api_key_here
|
||||
|
48
run_fasttd3_dev.slurm
Normal file
48
run_fasttd3_dev.slurm
Normal file
@ -0,0 +1,48 @@
|
||||
#!/bin/bash
#SBATCH --job-name=fasttd3_dev_test
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=dev_accelerated
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --output=fasttd3_dev_%j.out
#SBATCH --error=fasttd3_dev_%j.err

# Quick GPU smoke test for FastTD3 on HoReKa's dev_accelerated partition
# (30 min, 1 GPU, tiny timestep budget). Submit with: sbatch run_fasttd3_dev.slurm
#
# Fail fast: abort the job if a module load, cd, or any pipeline stage fails,
# so we never silently train in a half-configured environment.
# (-u is deliberately omitted: some venv activate scripts reference unset vars.)
set -eo pipefail

# Load necessary modules (HoReKa-specific: CUDA 12.4 + Intel LLVM compiler)
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory (quoted: path may contain spaces)
cd "$SLURM_SUBMIT_DIR"

# Activate the virtual environment
source .venv/bin/activate

# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES="$SLURM_LOCALID"
# Force JAX onto the CUDA backend ("cuda", not the older "gpu,cpu" spelling)
export JAX_PLATFORMS="cuda"
export JAX_ENABLE_X64=True

# For testing, use offline mode so no wandb login/API key is required
export WANDB_MODE=offline

echo "Starting FastTD3 dev test at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"

# Run FastTD3 training with minimal settings for a quick end-to-end test:
# small env count / batch size and only 5k timesteps so it fits the 30 min limit.
python fast_td3/train.py \
  --env_name T1JoystickFlatTerrain \
  --exp_name FastTD3_Dev_Test \
  --seed 42 \
  --total_timesteps 5000 \
  --num_envs 256 \
  --batch_size 1024 \
  --eval_interval 2500 \
  --render_interval 0 \
  --project FastTD3_HoReKa_Dev

echo "Job completed at $(date)"
|
Loading…
Reference in New Issue
Block a user