- Fixed JAX/PyTorch dtype mismatch for successful training
- Added experiment plan with paper-accurate hyperparameters
- Created batch submission and monitoring scripts
- Cleaned up log files and updated .gitignore
- Ready for systematic paper replication
47 lines
1.2 KiB
Bash
47 lines
1.2 KiB
Bash
#!/bin/bash
#SBATCH --job-name=fasttd3_dev_test
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=dev_accelerated
#SBATCH --time=00:30:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=4
#SBATCH --mem=16G
#SBATCH --output=fasttd3_dev_%j.out
#SBATCH --error=fasttd3_dev_%j.err

# Quick FastTD3 smoke test (30 min, 1 GPU) on the dev_accelerated partition.
# Submit with `sbatch` from the project root; export WANDB_API_KEY beforehand
# since WANDB_MODE is forced to "online" below.

# Fail fast: abort on any command error, unset variable, or failed pipeline
# stage, so a broken module load or missing venv doesn't burn GPU time.
set -euo pipefail

# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory (quoted: path may contain spaces;
# set -u aborts cleanly if SLURM_SUBMIT_DIR is somehow unset)
cd "$SLURM_SUBMIT_DIR"

# Activate the virtual environment
source .venv/bin/activate

# Set environment variables for proper GPU usage.
# FIX: SLURM_LOCALID is only defined for tasks launched via srun; in a plain
# sbatch batch step it is typically unset, which made the original
# `export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID` assign an EMPTY value and
# hide the allocated GPU from JAX. Default to device 0 of the allocation.
export CUDA_VISIBLE_DEVICES="${SLURM_LOCALID:-0}"
export JAX_PLATFORMS="cuda"

# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online

echo "Starting FastTD3 dev test at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"

# Run FastTD3 training with minimal settings for a quick end-to-end test.
# FIX: the original left a trailing '\' after the last argument, so the
# command only terminated because the following line happened to be blank;
# removing that blank line would have swallowed the final echo as an
# argument. The continuation now ends at the last real flag.
python fast_td3/train.py \
  --env_name T1JoystickFlatTerrain \
  --exp_name FastTD3_Dev_Test \
  --seed 42 \
  --total_timesteps 5000 \
  --num_envs 256 \
  --batch_size 1024 \
  --eval_interval 2500 \
  --render_interval 0 \
  --project FastTD3_HoReKa_Dev

echo "Job completed at $(date)"