Add FastTD3 HoReKa experiment management system

- Fixed JAX/PyTorch dtype mismatch for successful training
- Added experiment plan with paper-accurate hyperparameters
- Created batch submission and monitoring scripts
- Cleaned up log files and updated .gitignore
- Ready for systematic paper replication

parent 15750f56b2
commit e7e3ae48f1
.gitignore (vendored, 22 changes)

```diff
@@ -1,5 +1,23 @@
-models
-wandb
+# Model checkpoints and training outputs
+models/
+wandb/
+
+# Log files and job outputs
+logs/
+*.out
+*.err
+fasttd3_*.out
+fasttd3_*.err
+
+# Experiment tracking
+experiment_tracking_*.yaml
+
+# Python
 *.pyc
+__pycache__/
 .ipynb_checkpoints
 fast_td3.egg-info/
+.venv
+
+# SLURM scripts (generated)
+scripts/
```
README.md (11 changes)

````diff
@@ -88,25 +88,26 @@ The setup includes:
 - **Test script** (`test_setup.py`) for environment verification
 - **MuJoCo Playground environment** (`T1JoystickFlatTerrain`) for humanoid control
 - **Automatic GPU detection** and CUDA 12.4 compatibility
-- **Wandb logging** with offline mode support
+- **Wandb logging** with online mode by default
 
 ### Wandb Integration
 
 The scripts support both online and offline wandb logging:
 
-**Online mode:**
+**Online mode (default):**
 ```bash
 export WANDB_API_KEY=your_api_key_here
 python submit_job.py
-# Select 'y' when prompted for online mode
+# Select 'y' when prompted for online mode (default)
 ```
 
-**Offline mode (default):**
+**Offline mode:**
 ```bash
-# Jobs run in offline mode by default
+# Select 'n' when prompted for online mode
 # Sync later with: wandb sync <run_directory>
 ```
 
 ---
 
 # ORIGINAL README:
````
experiment_plan.md (new file, +67)

# FastTD3 HoReKa Experiment Plan

*Added by Dominik - Paper Replication Study*

## ✅ Proof of Concept Results

**Initial Success**: [HoReKa Dev Run](https://wandb.ai/rl-network-scaling/FastTD3_HoReKa_Dev?nw=nwuserdominik_roth)

- **Task**: T1JoystickFlatTerrain
- **Duration**: 7 minutes (5000 timesteps)
- **Performance**: Successfully training at ~29 it/s
- **Key Achievement**: Fixed the JAX/PyTorch dtype mismatch (removed `JAX_ENABLE_X64`)
- **Status**: ✅ Environment working, ready for full-scale experiments
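The dtype fix can be illustrated with a minimal sketch (the exact conversion site in FastTD3 is not shown in this commit; `to_policy_dtype` below is a hypothetical helper). With `JAX_ENABLE_X64` set, JAX hands back float64 buffers, while the PyTorch networks work in float32; dropping the flag, or casting explicitly at the boundary, keeps the dtypes aligned:

```python
import numpy as np

# With JAX_ENABLE_X64=True, JAX produces float64 observations; the
# PyTorch side of FastTD3 expects float32, so the dtypes clash.
# The commit's fix is to remove the flag; an explicit cast at the
# JAX -> PyTorch boundary would also work (hypothetical helper).
def to_policy_dtype(obs: np.ndarray) -> np.ndarray:
    return obs.astype(np.float32, copy=False)

obs_x64 = np.zeros(4, dtype=np.float64)   # what x64-enabled JAX produces
obs_x32 = to_policy_dtype(obs_x64)        # float32, safe to feed the policy
```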
## Experiments to Replicate

### Phase 1: MuJoCo Playground (Figure 11 from paper)

- `T1JoystickFlatTerrain` (3600s)
- `T1JoystickRoughTerrain` (3600s)
- `G1JoystickFlatTerrain` (3600s)
- `G1JoystickRoughTerrain` (3600s)

**Hyperparameters (from paper):**

- total_timesteps: 500000
- num_envs: 2048
- batch_size: 32768
- buffer_size: 102400 (50 per env)
- eval_interval: 25000

### Phase 2: IsaacLab (Figure 10 from paper)

- `Isaac-Velocity-Flat-G1-v0` (3600s)
- `Isaac-Velocity-Rough-G1-v0` (3600s)
- `Isaac-Repose-Cube-Allegro-Direct-v0` (3600s)
- `Isaac-Repose-Cube-Shadow-Direct-v0` (3600s)
- `Isaac-Velocity-Flat-H1-v0` (3600s)
- `Isaac-Velocity-Rough-H1-v0` (3600s)

**Hyperparameters:**

- total_timesteps: 1000000
- num_envs: 1024
- batch_size: 32768
- buffer_size: 51200
- eval_interval: 50000

### Phase 3: HumanoidBench (Figure 9 from paper - subset)

- `h1hand-walk` (10800s)
- `h1hand-run` (10800s)
- `h1hand-hurdle` (10800s)
- `h1hand-stair` (10800s)
- `h1hand-slide` (10800s)

**Hyperparameters:**

- total_timesteps: 2000000
- num_envs: 256
- batch_size: 16384
- buffer_size: 12800
- eval_interval: 100000
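One sanity check on these settings: the phases keep the same per-environment replay capacity, since `buffer_size / num_envs` works out to 50 in every case:

```python
# Replay capacity per environment for each phase: buffer_size // num_envs.
phases = {
    "phase1_mujoco": (102_400, 2048),
    "phase2_isaaclab": (51_200, 1024),
    "phase3_humanoidbench": (12_800, 256),
}
per_env = {name: buf // envs for name, (buf, envs) in phases.items()}
print(per_env)  # every phase stores 50 transitions per env
```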
## Usage

Submit Phase 1:

```bash
python submit_experiment_batch.py --phase 1 --seeds 3
```

Monitor progress:

```bash
python monitor_experiments.py --watch
```
monitor_experiments.py (new executable file, +168)

```python
#!/usr/bin/env python3
"""
Monitor FastTD3 experiments running on the HoReKa cluster.

Usage: python monitor_experiments.py [tracking_file.yaml]
"""

import argparse
import os
import subprocess
import time
from datetime import datetime

import yaml


def get_job_status():
    """Return a dict of the user's current SLURM jobs, keyed by job ID."""
    try:
        result = subprocess.run(
            ['squeue', '-u', os.environ['USER'], '--format=%i,%j,%t,%M,%N'],
            capture_output=True, text=True, check=True)

        jobs = {}
        lines = result.stdout.strip().split('\n')
        if len(lines) > 1:  # Skip the header row
            for line in lines[1:]:
                parts = line.split(',')
                if len(parts) >= 5:
                    job_id, name, state, time_used, node = parts[:5]
                    jobs[job_id] = {
                        'name': name,
                        'state': state,
                        'time_used': time_used,
                        'node': node,
                    }
        return jobs
    except subprocess.CalledProcessError:
        return {}


def check_wandb_run(project, run_name):
    """Check whether a wandb run exists and get basic stats.

    Placeholder: querying the wandb API requires credentials, so this
    currently returns static values.
    """
    return {
        'exists': True,
        'last_logged': datetime.now().isoformat(),
        'status': 'running',
    }


def format_duration(seconds):
    """Format a duration in seconds as a human-readable string."""
    if seconds < 60:
        return f"{seconds}s"
    elif seconds < 3600:
        return f"{seconds // 60}m {seconds % 60}s"
    else:
        return f"{seconds // 3600}h {(seconds % 3600) // 60}m"


def print_status_table(jobs_status):
    """Print a formatted status table."""
    print(f"{'Job ID':<10} {'Task':<25} {'Seed':<4} {'State':<12} {'Runtime':<12} {'Node':<10}")
    print("=" * 85)

    for job_info in jobs_status:
        job_id = job_info.get('job_id', 'N/A')
        task = job_info.get('task', 'Unknown')[:24]
        seed = str(job_info.get('seed', '?'))

        slurm_info = job_info.get('slurm_status', {})
        state = slurm_info.get('state', 'UNKNOWN')
        time_used = slurm_info.get('time_used', 'N/A')
        node = slurm_info.get('node', 'N/A')

        # Color coding for states (ANSI escapes)
        if state == 'RUNNING':
            state_color = f"\033[92m{state}\033[0m"  # Green
        elif state == 'PENDING':
            state_color = f"\033[93m{state}\033[0m"  # Yellow
        elif state == 'COMPLETED':
            state_color = f"\033[94m{state}\033[0m"  # Blue
        elif state in ('FAILED', 'CANCELLED'):
            state_color = f"\033[91m{state}\033[0m"  # Red
        else:
            state_color = state

        # Width 20 compensates for the invisible ANSI escape characters.
        print(f"{job_id:<10} {task:<25} {seed:<4} {state_color:<20} {time_used:<12} {node:<10}")


def main():
    parser = argparse.ArgumentParser(description='Monitor FastTD3 experiments')
    parser.add_argument('tracking_file', nargs='?',
                        help='YAML tracking file from batch submission')
    parser.add_argument('--watch', '-w', action='store_true',
                        help='Continuously monitor (refresh every 30s)')
    parser.add_argument('--summary', '-s', action='store_true',
                        help='Show summary statistics only')

    args = parser.parse_args()

    def update_and_display():
        os.system('clear' if os.name == 'posix' else 'cls')

        print("🔍 FastTD3 HoReKa Experiment Monitor")
        print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print()

        # Get current SLURM status
        slurm_jobs = get_job_status()

        if args.tracking_file and os.path.exists(args.tracking_file):
            # Load the tracking file written by submit_experiment_batch.py
            with open(args.tracking_file, 'r') as f:
                tracking_data = yaml.safe_load(f)

            jobs_status = []
            for job in tracking_data['jobs']:
                job_info = job.copy()
                job_info['slurm_status'] = slurm_jobs.get(job['job_id'],
                                                          {'state': 'NOT_FOUND'})
                jobs_status.append(job_info)

            print_status_table(jobs_status)

            # Summary statistics
            total_jobs = len(jobs_status)
            states = {}
            for job in jobs_status:
                state = job['slurm_status']['state']
                states[state] = states.get(state, 0) + 1

            print(f"\n📊 Summary ({total_jobs} jobs):")
            for state, count in sorted(states.items()):
                percentage = (count / total_jobs) * 100
                print(f"   {state}: {count} ({percentage:.1f}%)")
        else:
            # Show all user jobs if no tracking file was given
            if slurm_jobs:
                print("📋 All SLURM jobs:")
                jobs_list = [{'job_id': job_id,
                              'task': info['name'],
                              'seed': '?',
                              'slurm_status': info}
                             for job_id, info in slurm_jobs.items()]
                print_status_table(jobs_list)
            else:
                print("✅ No active jobs found")

        print("\n💡 Commands:")
        print("   squeue -u $USER              # Detailed SLURM status")
        print("   scancel <job_id>             # Cancel specific job")
        print("   tail -f logs/*.out           # Follow job logs")
        print("   python collect_results.py    # Gather completed results")

    if args.watch:
        try:
            while True:
                update_and_display()
                print("\n⏱️  Refreshing in 30s... (Ctrl+C to stop)")
                time.sleep(30)
        except KeyboardInterrupt:
            print("\n👋 Monitoring stopped")
    else:
        update_and_display()


if __name__ == "__main__":
    main()
```
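The `squeue --format=%i,%j,%t,%M,%N` call that `get_job_status()` relies on prints one comma-separated row per job after a header row; the split is straightforward (sample values below are made up):

```python
# One data row as squeue emits it with --format=%i,%j,%t,%M,%N:
# job id, job name, state code, elapsed time, node. Values illustrative.
line = "1234567,fasttd3_phase1_s1,R,1:23:45,hkn0701"
job_id, name, state, time_used, node = line.split(',')
```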
```diff
@@ -24,12 +24,11 @@ source .venv/bin/activate
 # Set environment variables for proper GPU usage
 export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
 export JAX_PLATFORMS="cuda"
-export JAX_ENABLE_X64=True
 
 # Ensure wandb is logged in (set WANDB_API_KEY environment variable)
 # export WANDB_API_KEY=your_api_key_here
-# For testing, use offline mode
-export WANDB_MODE=offline
+# Use online mode by default - set WANDB_API_KEY before running
+export WANDB_MODE=online
 
 # Run FastTD3 training with MuJoCo Playground environment
 python fast_td3/train.py \
@@ -41,6 +40,6 @@ python fast_td3/train.py \
 --batch_size 4096 \
 --eval_interval 5000 \
 --render_interval 0 \
---project FastTD3_HoReKa
+--project FastTD3_HoReKa \
 
 echo "Job completed at $(date)"
```
```diff
@@ -24,10 +24,9 @@ source .venv/bin/activate
 # Set environment variables for proper GPU usage
 export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
 export JAX_PLATFORMS="cuda"
-export JAX_ENABLE_X64=True
 
-# For testing, use offline mode
-export WANDB_MODE=offline
+# Use online mode by default - set WANDB_API_KEY before running
+export WANDB_MODE=online
 
 echo "Starting FastTD3 dev test at $(date)"
 echo "GPU: $CUDA_VISIBLE_DEVICES"
@@ -43,6 +42,6 @@ python fast_td3/train.py \
 --batch_size 1024 \
 --eval_interval 2500 \
 --render_interval 0 \
---project FastTD3_HoReKa_Dev
+--project FastTD3_HoReKa_Dev \
 
 echo "Job completed at $(date)"
```
run_fasttd3_full.slurm (new file, +51)

```bash
#!/bin/bash
#SBATCH --job-name=fasttd3_full
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time=04:00:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --output=fasttd3_full_%j.out
#SBATCH --error=fasttd3_full_%j.err

# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory
cd $SLURM_SUBMIT_DIR

# Activate the virtual environment
source .venv/bin/activate

# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"

# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online

echo "Starting FastTD3 full training at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"
echo "Task: ${TASK_NAME:-T1JoystickFlatTerrain}"
echo "Seed: ${SEED:-42}"
echo "Timesteps: ${TOTAL_TIMESTEPS:-500000}"

# Run FastTD3 training with paper-based hyperparameters
python fast_td3/train.py \
--env_name ${TASK_NAME:-T1JoystickFlatTerrain} \
--exp_name FastTD3_Full_${TASK_NAME:-T1JoystickFlatTerrain} \
--seed ${SEED:-42} \
--total_timesteps ${TOTAL_TIMESTEPS:-500000} \
--num_envs ${NUM_ENVS:-2048} \
--batch_size ${BATCH_SIZE:-32768} \
--buffer_size ${BUFFER_SIZE:-102400} \
--eval_interval ${EVAL_INTERVAL:-25000} \
--render_interval 0 \
--project FastTD3_HoReKa_Full

echo "Job completed at $(date)"
```
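The `${VAR:-default}` expansions above mean every setting can be overridden per job through the environment (for example via `sbatch --export`); a quick demonstration of the fallback behaviour:

```shell
# ${VAR:-default} expands to the default only when VAR is unset or empty.
unset TASK_NAME
echo "${TASK_NAME:-T1JoystickFlatTerrain}"
TASK_NAME=G1JoystickRoughTerrain
echo "${TASK_NAME:-T1JoystickFlatTerrain}"
```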
submit_experiment_batch.py (new executable file, +225)

```python
#!/usr/bin/env python3
"""
Batch experiment submission script for FastTD3 paper replication.

Usage: python submit_experiment_batch.py --phase 1 --tasks all --seeds 3
"""

import argparse
import os
import subprocess
import time

import yaml

# Experiment configuration based on the paper
EXPERIMENT_CONFIG = {
    "phase1_mujoco": {
        "tasks": [
            "T1JoystickFlatTerrain",
            "T1JoystickRoughTerrain",
            "G1JoystickFlatTerrain",
            "G1JoystickRoughTerrain",
        ],
        "total_timesteps": 500000,
        "num_envs": 2048,
        "batch_size": 32768,
        "buffer_size": 102400,  # 50 per env across 2048 envs
        "eval_interval": 25000,
        "time_limit": "04:00:00",
        "mem": "64G",
    },
    "phase2_isaaclab": {
        "tasks": [
            "Isaac-Velocity-Flat-G1-v0",
            "Isaac-Velocity-Rough-G1-v0",
            "Isaac-Repose-Cube-Allegro-Direct-v0",
            "Isaac-Repose-Cube-Shadow-Direct-v0",
            "Isaac-Velocity-Flat-H1-v0",
            "Isaac-Velocity-Rough-H1-v0",
        ],
        "total_timesteps": 1000000,
        "num_envs": 1024,
        "batch_size": 32768,
        "buffer_size": 51200,
        "eval_interval": 50000,
        "time_limit": "04:00:00",
        "mem": "64G",
    },
    "phase3_humanoidbench": {
        "tasks": [
            "h1hand-walk",
            "h1hand-run",
            "h1hand-hurdle",
            "h1hand-stair",
            "h1hand-slide",
        ],
        "total_timesteps": 2000000,
        "num_envs": 256,
        "batch_size": 16384,
        "buffer_size": 12800,
        "eval_interval": 100000,
        "time_limit": "12:00:00",  # 12 hours for HumanoidBench
        "mem": "64G",
    },
}


def create_job_script(task, config, seed, phase):
    """Create a SLURM script for a specific task/seed combination."""

    script_content = f'''#!/bin/bash
#SBATCH --job-name=fasttd3_{phase}_{task.replace("-", "_")}_s{seed}
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time={config["time_limit"]}
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem={config["mem"]}
#SBATCH --output=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.out
#SBATCH --error=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.err

# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm

# Navigate to the project directory
cd $SLURM_SUBMIT_DIR

# Activate the virtual environment
source .venv/bin/activate

# Set environment variables
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
export WANDB_MODE=online

echo "Starting FastTD3 {phase} training at $(date)"
echo "Task: {task}"
echo "Seed: {seed}"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"

# Run FastTD3 training
python fast_td3/train.py \\
--env_name {task} \\
--exp_name FastTD3_{phase}_{task.replace("-", "_")} \\
--seed {seed} \\
--total_timesteps {config["total_timesteps"]} \\
--num_envs {config["num_envs"]} \\
--batch_size {config["batch_size"]} \\
--buffer_size {config["buffer_size"]} \\
--eval_interval {config["eval_interval"]} \\
--render_interval 0 \\
--project FastTD3_HoReKa_{phase.title()}

echo "Job completed at $(date)"
'''
    return script_content


def submit_job(script_path, dry_run=False):
    """Submit a SLURM job and return its job ID."""
    if dry_run:
        print(f"[DRY RUN] Would submit: {script_path}")
        return "12345"  # Fake job ID

    try:
        result = subprocess.run(['sbatch', script_path],
                                capture_output=True, text=True, check=True)
        # sbatch prints "Submitted batch job <id>"; take the last token.
        job_id = result.stdout.strip().split()[-1]
        print(f"✅ Submitted {script_path} -> Job ID: {job_id}")
        return job_id
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to submit {script_path}: {e.stderr}")
        return None


def main():
    parser = argparse.ArgumentParser(description='Submit FastTD3 experiment batch')
    parser.add_argument('--phase', type=str, choices=['1', '2', '3', 'all'],
                        default='1', help='Experiment phase to run')
    parser.add_argument('--tasks', type=str, default='all',
                        help='Comma-separated task names or "all"')
    parser.add_argument('--seeds', type=int, default=3,
                        help='Number of random seeds to run')
    parser.add_argument('--dry-run', action='store_true',
                        help='Print commands without executing')
    parser.add_argument('--delay', type=int, default=5,
                        help='Delay between job submissions (seconds)')

    args = parser.parse_args()

    # Create output directories
    os.makedirs('logs', exist_ok=True)
    os.makedirs('scripts', exist_ok=True)

    # Determine which phases to run
    if args.phase == 'all':
        phases = ['phase1_mujoco', 'phase2_isaaclab', 'phase3_humanoidbench']
    else:
        phase_map = {'1': 'phase1_mujoco', '2': 'phase2_isaaclab',
                     '3': 'phase3_humanoidbench'}
        phases = [phase_map[args.phase]]

    submitted_jobs = []

    for phase in phases:
        config = EXPERIMENT_CONFIG[phase]

        # Determine tasks to run
        if args.tasks == 'all':
            tasks = config['tasks']
        else:
            tasks = [t.strip() for t in args.tasks.split(',')]
            # Validate tasks exist in config
            invalid = set(tasks) - set(config['tasks'])
            if invalid:
                print(f"❌ Invalid tasks for {phase}: {invalid}")
                continue

        print(f"\n🚀 Starting {phase} with tasks: {tasks}")
        print(f"   Seeds: {list(range(1, args.seeds + 1))}")

        for task in tasks:
            for seed in range(1, args.seeds + 1):
                # Create the job script
                script_content = create_job_script(task, config, seed, phase)
                script_name = (f"scripts/fasttd3_{phase}_"
                               f"{task.replace('-', '_')}_s{seed}.slurm")

                with open(script_name, 'w') as f:
                    f.write(script_content)

                # Submit the job
                job_id = submit_job(script_name, args.dry_run)
                if job_id:
                    submitted_jobs.append({
                        'job_id': job_id,
                        'phase': phase,
                        'task': task,
                        'seed': seed,
                        'script': script_name,
                    })

                # Delay between submissions to avoid overwhelming the scheduler
                if not args.dry_run and args.delay > 0:
                    time.sleep(args.delay)

    # Summary
    print(f"\n📊 Submission Summary:")
    print(f"   Total jobs submitted: {len(submitted_jobs)}")

    if submitted_jobs:
        # Save job tracking info for monitor_experiments.py
        tracking_file = f"experiment_tracking_{int(time.time())}.yaml"
        with open(tracking_file, 'w') as f:
            yaml.dump({
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'jobs': submitted_jobs,
            }, f, default_flow_style=False)

        print(f"   Job tracking saved to: {tracking_file}")
        print(f"\n💡 Monitor progress with:")
        print(f"   squeue -u $USER")
        print(f"   python monitor_experiments.py {tracking_file}")


if __name__ == "__main__":
    main()
```
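For reference, the tracking file this script writes (and `monitor_experiments.py` reads back with `yaml.safe_load`) looks roughly like the sketch below; the job ID and timestamp are illustrative, and `yaml.dump` may order the keys differently:

```yaml
timestamp: '2025-06-01 12:00:00'
jobs:
- job_id: '1234567'
  phase: phase1_mujoco
  task: T1JoystickFlatTerrain
  seed: 1
  script: scripts/fasttd3_phase1_mujoco_T1JoystickFlatTerrain_s1.slurm
```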