- Fixed JAX/PyTorch dtype mismatch for successful training - Added experiment plan with paper-accurate hyperparameters - Created batch submission and monitoring scripts - Cleaned up log files and updated gitignore - Ready for systematic paper replication
45 lines
1.2 KiB
Bash
45 lines
1.2 KiB
Bash
#!/bin/bash
#
# Slurm batch job: short FastTD3 smoke-test run on a single GPU.
#
# Usage:
#   sbatch <this-script>      (submit from the project root, which must
#                              contain .venv/ and fast_td3/train.py)
#
# Environment:
#   WANDB_API_KEY  credentials for Weights & Biases; export it before
#                  submitting (do not hard-code it in this file).
#   WANDB_MODE     optional override; defaults to "online".
#
# Exit status: that of the training process, so Slurm records a failed
# training run as a failed job.
#
#SBATCH --job-name=fasttd3_test
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time=02:00:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=32G
#SBATCH --output=fasttd3_%j.out
#SBATCH --error=fasttd3_%j.err

# Print a message to stderr and abort the job with a non-zero status.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Load necessary modules
module purge
module load devel/cuda/12.4 || die "cannot load module devel/cuda/12.4"
module load compiler/intel/2025.1_llvm \
  || die "cannot load module compiler/intel/2025.1_llvm"

# Navigate to the project directory. Fall back to the current directory when
# SLURM_SUBMIT_DIR is not set (e.g. the script is run outside of sbatch);
# a bare `cd` with an empty argument would silently go to $HOME instead.
project_dir="${SLURM_SUBMIT_DIR:-$PWD}"
cd -- "$project_dir" || die "cannot cd to ${project_dir}"

# Activate the virtual environment
[[ -f .venv/bin/activate ]] \
  || die "virtual environment not found: ${project_dir}/.venv"
# shellcheck disable=SC1091 -- the venv is created outside this repository
source .venv/bin/activate || die "cannot activate .venv"

# Set environment variables for proper GPU usage.
# Keep a GPU mask that is already present in the environment (Slurm normally
# provides one for --gres=gpu jobs); only fall back to the node-local task id,
# or device 0, when nothing is set. Unconditionally assigning an empty
# $SLURM_LOCALID would hide every GPU from the process.
if [[ -z "${CUDA_VISIBLE_DEVICES:-}" ]]; then
  export CUDA_VISIBLE_DEVICES="${SLURM_LOCALID:-0}"
fi
export JAX_PLATFORMS="cuda"

# Ensure wandb is logged in (set WANDB_API_KEY environment variable)
# export WANDB_API_KEY=your_api_key_here
# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE="${WANDB_MODE:-online}"
if [[ "$WANDB_MODE" == "online" && -z "${WANDB_API_KEY:-}" ]]; then
  # Not fatal: credentials may also come from a previous `wandb login`.
  printf 'warning: WANDB_MODE=online but WANDB_API_KEY is not set\n' >&2
fi

# Run FastTD3 training with MuJoCo Playground environment.
# Arguments live in an array so no line-continuation backslashes are needed
# (a trailing backslash after the last argument is easy to break).
train_args=(
  --env_name T1JoystickFlatTerrain
  --exp_name FastTD3_HoReKa_Test
  --seed 42
  --total_timesteps 25000
  --num_envs 1024
  --batch_size 4096
  --eval_interval 5000
  --render_interval 0
  --project FastTD3_HoReKa
)

train_status=0
python fast_td3/train.py "${train_args[@]}" || train_status=$?

if (( train_status != 0 )); then
  printf 'Job failed at %s (exit status %d)\n' "$(date)" "$train_status" >&2
  exit "$train_status"
fi

echo "Job completed at $(date)"