diff --git a/README.md b/README.md
index af6aefe..f8b77ab 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,11 @@ pip install --upgrade pip
 pip install -r requirements/requirements.txt
 pip install git+https://github.com/younggyoseo/mujoco_playground.git
 
+# IMPORTANT: Downgrade JAX for HoReKa compatibility
+# HoReKa's NVIDIA driver/cuDNN stack is older (cuDNN 9.5.1) and incompatible with the latest JAX
+pip uninstall jax jaxlib jax-cuda12-plugin -y
+pip install "jax[cuda12]==0.4.35" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
 # Test installation
 python test_setup.py
 ```
diff --git a/run_fasttd3.slurm b/run_fasttd3.slurm
index 1f59e97..7382ba1 100644
--- a/run_fasttd3.slurm
+++ b/run_fasttd3.slurm
@@ -12,7 +12,8 @@
 
 # Load necessary modules
 module purge
-module load toolkit/CUDA/12.4
+module load devel/cuda/12.4
+module load compiler/intel/2025.1_llvm
 
 # Navigate to the project directory
 cd $SLURM_SUBMIT_DIR
@@ -22,7 +23,8 @@ source .venv/bin/activate
 
 # Set environment variables for proper GPU usage
 export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
-export JAX_PLATFORMS="gpu,cpu"
+export JAX_PLATFORMS="cuda"
+export JAX_ENABLE_X64=True
 
 # Ensure wandb is logged in (set WANDB_API_KEY environment variable)
 # export WANDB_API_KEY=your_api_key_here
diff --git a/run_fasttd3_dev.slurm b/run_fasttd3_dev.slurm
new file mode 100644
index 0000000..cf14c2e
--- /dev/null
+++ b/run_fasttd3_dev.slurm
@@ -0,0 +1,58 @@
+#!/bin/bash
+#SBATCH --job-name=fasttd3_dev_test
+#SBATCH --account=hk-project-p0022232
+#SBATCH --partition=dev_accelerated
+#SBATCH --time=00:30:00
+#SBATCH --gres=gpu:1
+#SBATCH --ntasks=1
+#SBATCH --cpus-per-task=4
+#SBATCH --mem=16G
+#SBATCH --output=fasttd3_dev_%j.out
+#SBATCH --error=fasttd3_dev_%j.err
+
+# Short smoke test of the FastTD3 training pipeline on the dev_accelerated
+# partition: one GPU, a 30-minute limit, and a few thousand timesteps.
+
+# Load necessary modules
+module purge
+module load devel/cuda/12.4
+module load compiler/intel/2025.1_llvm
+
+# Navigate to the project directory; abort rather than run from the wrong place
+cd "$SLURM_SUBMIT_DIR" || exit 1
+
+# Activate the virtual environment; abort if it is missing so that the job
+# never falls back to another Python installation
+source .venv/bin/activate || exit 1
+
+# Set environment variables for proper GPU usage.
+# Fall back to device 0 when SLURM_LOCALID is unset, instead of exporting an
+# empty CUDA_VISIBLE_DEVICES that would hide every GPU.
+export CUDA_VISIBLE_DEVICES="${SLURM_LOCALID:-0}"
+export JAX_PLATFORMS="cuda"
+export JAX_ENABLE_X64=True
+
+# For testing, use offline mode
+export WANDB_MODE=offline
+
+echo "Starting FastTD3 dev test at $(date)"
+echo "GPU: $CUDA_VISIBLE_DEVICES"
+echo "Node: $(hostname)"
+
+# Run FastTD3 training with minimal settings for quick test
+python fast_td3/train.py \
+    --env_name T1JoystickFlatTerrain \
+    --exp_name FastTD3_Dev_Test \
+    --seed 42 \
+    --total_timesteps 5000 \
+    --num_envs 256 \
+    --batch_size 1024 \
+    --eval_interval 2500 \
+    --render_interval 0 \
+    --project FastTD3_HoReKa_Dev
+# Remember the training result so the final echo does not mask a failure
+status=$?
+
+echo "Job completed at $(date)"
+# Propagate the training exit code so Slurm marks a failed run as FAILED
+exit "$status"