- Fixed JAX/PyTorch dtype mismatch for successful training
- Added experiment plan with paper-accurate hyperparameters
- Created batch submission and monitoring scripts
- Cleaned up log files and updated .gitignore
- Ready for systematic paper replication
51 lines · 1.5 KiB · Bash
#!/bin/bash
#SBATCH --job-name=fasttd3_full
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time=04:00:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --output=fasttd3_full_%j.out
#SBATCH --error=fasttd3_full_%j.err

# Launch a single-GPU FastTD3 training run on the HoReKa cluster.
# Hyperparameters are overridable via environment variables at submit time:
#   TASK_NAME, SEED, TOTAL_TIMESTEPS, NUM_ENVS, BATCH_SIZE, BUFFER_SIZE,
#   EVAL_INTERVAL
# Set WANDB_API_KEY before submitting (WANDB_MODE defaults to online).

# Abort on any command failure, unset variable, or failed pipeline stage,
# so a broken module load / venv activation never silently starts training.
set -euo pipefail

# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory; :? aborts with a message when the
# script is run outside of sbatch (SLURM_SUBMIT_DIR unset).
cd "${SLURM_SUBMIT_DIR:?SLURM_SUBMIT_DIR is unset - submit via sbatch}"

# Activate the virtual environment
source .venv/bin/activate

# GPU selection: with --gres=gpu:1 SLURM usually exports a cgroup-scoped
# CUDA_VISIBLE_DEVICES already - respect it if present, otherwise fall back
# to the task-local ID (0 for a single task).
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-${SLURM_LOCALID:-0}}"
export JAX_PLATFORMS="cuda"

# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online

# Resolve hyperparameters once (paper-based defaults) so the log banner and
# the training command are guaranteed to agree.
readonly task_name="${TASK_NAME:-T1JoystickFlatTerrain}"
readonly seed="${SEED:-42}"
readonly total_timesteps="${TOTAL_TIMESTEPS:-500000}"
readonly num_envs="${NUM_ENVS:-2048}"
readonly batch_size="${BATCH_SIZE:-32768}"
readonly buffer_size="${BUFFER_SIZE:-102400}"
readonly eval_interval="${EVAL_INTERVAL:-25000}"

echo "Starting FastTD3 full training at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"
echo "Task: ${task_name}"
echo "Seed: ${seed}"
echo "Timesteps: ${total_timesteps}"

# Run FastTD3 training with paper-based hyperparameters
python fast_td3/train.py \
  --env_name "$task_name" \
  --exp_name "FastTD3_Full_${task_name}" \
  --seed "$seed" \
  --total_timesteps "$total_timesteps" \
  --num_envs "$num_envs" \
  --batch_size "$batch_size" \
  --buffer_size "$buffer_size" \
  --eval_interval "$eval_interval" \
  --render_interval 0 \
  --project FastTD3_HoReKa_Full

echo "Job completed at $(date)"