- Update SLURM scripts to use correct CUDA modules (devel/cuda/12.4, Intel compiler)
- Add JAX downgrade to 0.4.35 for cuDNN 9.5.1 compatibility
- Fix JAX_PLATFORMS environment variable ("cuda" instead of "gpu,cpu")
- Update README with cluster-specific JAX installation steps
- Tested successfully: both PyTorch and JAX working on GPU with full training
48 lines
1.2 KiB
Bash
48 lines
1.2 KiB
Bash
#!/bin/bash
#
# SLURM job script: quick FastTD3 smoke test on a single GPU (dev partition).
# Submit from the repo root with: sbatch fasttd3_dev_test.sh
# The #SBATCH header must stay contiguous — SLURM stops parsing directives
# at the first non-comment line.
#
#SBATCH --job-name=fasttd3_dev_test
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=dev_accelerated
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --output=fasttd3_dev_%j.out
#SBATCH --error=fasttd3_dev_%j.err

# Fail fast: abort on unhandled errors, unset variables, and pipeline failures.
set -euo pipefail

# Load necessary modules (cluster-specific CUDA + compiler toolchain).
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory (the directory sbatch was invoked from).
cd "$SLURM_SUBMIT_DIR" || exit 1

# Activate the virtual environment.
source .venv/bin/activate

# Set environment variables for proper GPU usage.
# NOTE(review): with --ntasks=1, SLURM_LOCALID is 0 in the batch step; SLURM
# normally manages CUDA_VISIBLE_DEVICES itself when --gres=gpu:N is requested
# — confirm this manual override is actually needed on this cluster.
export CUDA_VISIBLE_DEVICES="${SLURM_LOCALID:-0}"
export JAX_PLATFORMS="cuda"   # restrict JAX to the CUDA backend only
export JAX_ENABLE_X64=True    # enable 64-bit floats/ints in JAX

# For testing, use offline mode (compute nodes may lack network access).
export WANDB_MODE=offline

echo "Starting FastTD3 dev test at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"

# Run FastTD3 training with minimal settings for a quick smoke test.
python fast_td3/train.py \
  --env_name T1JoystickFlatTerrain \
  --exp_name FastTD3_Dev_Test \
  --seed 42 \
  --total_timesteps 5000 \
  --num_envs 256 \
  --batch_size 1024 \
  --eval_interval 2500 \
  --render_interval 0 \
  --project FastTD3_HoReKa_Dev

echo "Job completed at $(date)"