Add FastTD3 HoReKa experiment management system

- Fixed JAX/PyTorch dtype mismatch for successful training
- Added experiment plan with paper-accurate hyperparameters
- Created batch submission and monitoring scripts
- Cleaned up log files and updated gitignore
- Ready for systematic paper replication
This commit is contained in:
ys1087@partner.kit.edu 2025-07-22 17:07:20 +02:00
parent 15750f56b2
commit e7e3ae48f1
8 changed files with 543 additions and 15 deletions

22
.gitignore vendored
View File

@ -1,5 +1,23 @@
models
wandb
# Model checkpoints and training outputs
models/
wandb/
# Log files and job outputs
logs/
*.out
*.err
fasttd3_*.out
fasttd3_*.err
# Experiment tracking
experiment_tracking_*.yaml
# Python
*.pyc
__pycache__/
.ipynb_checkpoints
fast_td3.egg-info/
.venv
# SLURM scripts (generated)
scripts/

View File

@ -88,25 +88,26 @@ The setup includes:
- **Test script** (`test_setup.py`) for environment verification
- **MuJoCo Playground environment** (`T1JoystickFlatTerrain`) for humanoid control
- **Automatic GPU detection** and CUDA 12.4 compatibility
- **Wandb logging** with offline mode support
- **Wandb logging** with online mode by default
### Wandb Integration
The scripts support both online and offline wandb logging:
**Online mode:**
**Online mode (default):**
```bash
export WANDB_API_KEY=your_api_key_here
python submit_job.py
# Select 'y' when prompted for online mode
# Select 'y' when prompted for online mode (default)
```
**Offline mode (default):**
**Offline mode:**
```bash
# Jobs run in offline mode by default
# Select 'n' when prompted for online mode
# Sync later with: wandb sync <run_directory>
```
---
# ORIGINAL README:

67
experiment_plan.md Normal file
View File

@ -0,0 +1,67 @@
# FastTD3 HoReKa Experiment Plan
*Added by Dominik - Paper Replication Study*
## ✅ Proof of Concept Results
**Initial Success**: [HoReKa Dev Run](https://wandb.ai/rl-network-scaling/FastTD3_HoReKa_Dev?nw=nwuserdominik_roth)
- **Task**: T1JoystickFlatTerrain
- **Duration**: 7 minutes (5000 timesteps)
- **Performance**: Successfully training at ~29 it/s
- **Key Achievement**: Fixed JAX/PyTorch dtype mismatch issue (removed JAX_ENABLE_X64)
- **Status**: ✅ Environment working, ready for full-scale experiments
## Experiments to Replicate
### Phase 1: MuJoCo Playground (Figure 11 from paper)
- `T1JoystickFlatTerrain` (3600s)
- `T1JoystickRoughTerrain` (3600s)
- `G1JoystickFlatTerrain` (3600s)
- `G1JoystickRoughTerrain` (3600s)
**Hyperparameters (from paper):**
- total_timesteps: 500000
- num_envs: 2048
- batch_size: 32768
- buffer_size: 102400 (50K per env)
- eval_interval: 25000
### Phase 2: IsaacLab (Figure 10 from paper)
- `Isaac-Velocity-Flat-G1-v0` (3600s)
- `Isaac-Velocity-Rough-G1-v0` (3600s)
- `Isaac-Repose-Cube-Allegro-Direct-v0` (3600s)
- `Isaac-Repose-Cube-Shadow-Direct-v0` (3600s)
- `Isaac-Velocity-Flat-H1-v0` (3600s)
- `Isaac-Velocity-Rough-H1-v0` (3600s)
**Hyperparameters:**
- total_timesteps: 1000000
- num_envs: 1024
- batch_size: 32768
- buffer_size: 51200
- eval_interval: 50000
### Phase 3: HumanoidBench (Figure 9 from paper - subset)
- `h1hand-walk` (10800s)
- `h1hand-run` (10800s)
- `h1hand-hurdle` (10800s)
- `h1hand-stair` (10800s)
- `h1hand-slide` (10800s)
**Hyperparameters:**
- total_timesteps: 2000000
- num_envs: 256
- batch_size: 16384
- buffer_size: 12800
- eval_interval: 100000
## Usage
Submit Phase 1:
```bash
python submit_experiment_batch.py --phase 1 --seeds 3
```
Monitor progress:
```bash
python monitor_experiments.py --watch
```

168
monitor_experiments.py Executable file
View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Monitor FastTD3 experiments running on HoReKa cluster.
Usage: python monitor_experiments.py [tracking_file.yaml]
"""
import os
import subprocess
import argparse
import yaml
import time
import requests
from datetime import datetime, timedelta
from pathlib import Path
def get_job_status():
"""Get current SLURM job status for user."""
try:
result = subprocess.run(['squeue', '-u', os.environ['USER'], '--format=%i,%j,%t,%M,%N'],
capture_output=True, text=True, check=True)
jobs = {}
lines = result.stdout.strip().split('\n')
if len(lines) > 1: # Skip header
for line in lines[1:]:
parts = line.split(',')
if len(parts) >= 5:
job_id, name, state, time_used, node = parts
jobs[job_id] = {
'name': name,
'state': state,
'time_used': time_used,
'node': node
}
return jobs
except subprocess.CalledProcessError:
return {}
def check_wandb_run(project, run_name):
"""Check if wandb run exists and get basic stats."""
# Note: This requires wandb API access
# For now, just return placeholder
return {
'exists': True,
'last_logged': datetime.now().isoformat(),
'status': 'running'
}
def format_duration(seconds):
"""Format duration in human readable format."""
if seconds < 60:
return f"{seconds}s"
elif seconds < 3600:
return f"{seconds//60}m {seconds%60}s"
else:
return f"{seconds//3600}h {(seconds%3600)//60}m"
def print_status_table(jobs_status):
"""Print formatted status table."""
print(f"{'Job ID':<10} {'Task':<25} {'Seed':<4} {'State':<12} {'Runtime':<12} {'Node':<10}")
print("=" * 85)
for job_info in jobs_status:
job_id = job_info.get('job_id', 'N/A')
task = job_info.get('task', 'Unknown')[:24]
seed = str(job_info.get('seed', '?'))
slurm_info = job_info.get('slurm_status', {})
state = slurm_info.get('state', 'UNKNOWN')
time_used = slurm_info.get('time_used', 'N/A')
node = slurm_info.get('node', 'N/A')
# Color coding for states
if state == 'RUNNING':
state_color = f"\\033[92m{state}\\033[0m" # Green
elif state == 'PENDING':
state_color = f"\\033[93m{state}\\033[0m" # Yellow
elif state == 'COMPLETED':
state_color = f"\\033[94m{state}\\033[0m" # Blue
elif state in ['FAILED', 'CANCELLED']:
state_color = f"\\033[91m{state}\\033[0m" # Red
else:
state_color = state
print(f"{job_id:<10} {task:<25} {seed:<4} {state_color:<20} {time_used:<12} {node:<10}")
def main():
parser = argparse.ArgumentParser(description='Monitor FastTD3 experiments')
parser.add_argument('tracking_file', nargs='?',
help='YAML tracking file from batch submission')
parser.add_argument('--watch', '-w', action='store_true',
help='Continuously monitor (refresh every 30s)')
parser.add_argument('--summary', '-s', action='store_true',
help='Show summary statistics only')
args = parser.parse_args()
def update_and_display():
os.system('clear' if os.name == 'posix' else 'cls')
print(f"🔍 FastTD3 HoReKa Experiment Monitor")
print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print()
# Get current SLURM status
slurm_jobs = get_job_status()
if args.tracking_file and os.path.exists(args.tracking_file):
# Load tracking file
with open(args.tracking_file, 'r') as f:
tracking_data = yaml.safe_load(f)
jobs_status = []
for job in tracking_data['jobs']:
job_id = job['job_id']
job_info = job.copy()
job_info['slurm_status'] = slurm_jobs.get(job_id, {'state': 'NOT_FOUND'})
jobs_status.append(job_info)
print_status_table(jobs_status)
# Summary statistics
total_jobs = len(jobs_status)
states = {}
for job in jobs_status:
state = job['slurm_status']['state']
states[state] = states.get(state, 0) + 1
print(f"\\n📊 Summary ({total_jobs} jobs):")
for state, count in sorted(states.items()):
percentage = (count / total_jobs) * 100
print(f" {state}: {count} ({percentage:.1f}%)")
else:
# Show all user jobs if no tracking file
if slurm_jobs:
print("📋 All SLURM jobs:")
jobs_list = []
for job_id, info in slurm_jobs.items():
jobs_list.append({
'job_id': job_id,
'task': info['name'],
'seed': '?',
'slurm_status': info
})
print_status_table(jobs_list)
else:
print("✅ No active jobs found")
print(f"\\n💡 Commands:")
print(f" squeue -u $USER # Detailed SLURM status")
print(f" scancel <job_id> # Cancel specific job")
print(f" tail -f logs/*.out # Follow job logs")
print(f" python collect_results.py # Gather completed results")
if args.watch:
try:
while True:
update_and_display()
print(f"\\n⏱ Refreshing in 30s... (Ctrl+C to stop)")
time.sleep(30)
except KeyboardInterrupt:
print("\\n👋 Monitoring stopped")
else:
update_and_display()
if __name__ == "__main__":
main()

View File

@ -24,12 +24,11 @@ source .venv/bin/activate
# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
export JAX_ENABLE_X64=True
# Ensure wandb is logged in (set WANDB_API_KEY environment variable)
# export WANDB_API_KEY=your_api_key_here
# For testing, use offline mode
export WANDB_MODE=offline
# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online
# Run FastTD3 training with MuJoCo Playground environment
python fast_td3/train.py \
@ -41,6 +40,6 @@ python fast_td3/train.py \
--batch_size 4096 \
--eval_interval 5000 \
--render_interval 0 \
--project FastTD3_HoReKa
--project FastTD3_HoReKa \
echo "Job completed at $(date)"

View File

@ -24,10 +24,9 @@ source .venv/bin/activate
# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
export JAX_ENABLE_X64=True
# For testing, use offline mode
export WANDB_MODE=offline
# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online
echo "Starting FastTD3 dev test at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
@ -43,6 +42,6 @@ python fast_td3/train.py \
--batch_size 1024 \
--eval_interval 2500 \
--render_interval 0 \
--project FastTD3_HoReKa_Dev
--project FastTD3_HoReKa_Dev \
echo "Job completed at $(date)"

51
run_fasttd3_full.slurm Normal file
View File

@ -0,0 +1,51 @@
#!/bin/bash
#SBATCH --job-name=fasttd3_full
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time=04:00:00
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --output=fasttd3_full_%j.out
#SBATCH --error=fasttd3_full_%j.err
# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm
# Navigate to the project directory
cd $SLURM_SUBMIT_DIR
# Activate the virtual environment
source .venv/bin/activate
# Set environment variables for proper GPU usage
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
# Use online mode by default - set WANDB_API_KEY before running
export WANDB_MODE=online
echo "Starting FastTD3 full training at $(date)"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"
echo "Task: ${TASK_NAME:-T1JoystickFlatTerrain}"
echo "Seed: ${SEED:-42}"
echo "Timesteps: ${TOTAL_TIMESTEPS:-500000}"
# Run FastTD3 training with paper-based hyperparameters
python fast_td3/train.py \
--env_name ${TASK_NAME:-T1JoystickFlatTerrain} \
--exp_name FastTD3_Full_${TASK_NAME:-T1JoystickFlatTerrain} \
--seed ${SEED:-42} \
--total_timesteps ${TOTAL_TIMESTEPS:-500000} \
--num_envs ${NUM_ENVS:-2048} \
--batch_size ${BATCH_SIZE:-32768} \
--buffer_size ${BUFFER_SIZE:-102400} \
--eval_interval ${EVAL_INTERVAL:-25000} \
--render_interval 0 \
--project FastTD3_HoReKa_Full
echo "Job completed at $(date)"

225
submit_experiment_batch.py Executable file
View File

@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""
Batch experiment submission script for FastTD3 paper replication.
Usage: python submit_experiment_batch.py --phase 1 --tasks all --seeds 3
"""
import os
import subprocess
import argparse
import time
import yaml
from pathlib import Path
# Experiment configuration based on paper
EXPERIMENT_CONFIG = {
"phase1_mujoco": {
"tasks": [
"T1JoystickFlatTerrain",
"T1JoystickRoughTerrain",
"G1JoystickFlatTerrain",
"G1JoystickRoughTerrain"
],
"total_timesteps": 500000,
"num_envs": 2048,
"batch_size": 32768,
"buffer_size": 102400, # 50K per env for 2048 envs
"eval_interval": 25000,
"time_limit": "04:00:00",
"mem": "64G"
},
"phase2_isaaclab": {
"tasks": [
"Isaac-Velocity-Flat-G1-v0",
"Isaac-Velocity-Rough-G1-v0",
"Isaac-Repose-Cube-Allegro-Direct-v0",
"Isaac-Repose-Cube-Shadow-Direct-v0",
"Isaac-Velocity-Flat-H1-v0",
"Isaac-Velocity-Rough-H1-v0"
],
"total_timesteps": 1000000,
"num_envs": 1024,
"batch_size": 32768,
"buffer_size": 51200,
"eval_interval": 50000,
"time_limit": "04:00:00",
"mem": "64G"
},
"phase3_humanoidbench": {
"tasks": [
"h1hand-walk",
"h1hand-run",
"h1hand-hurdle",
"h1hand-stair",
"h1hand-slide"
],
"total_timesteps": 2000000,
"num_envs": 256,
"batch_size": 16384,
"buffer_size": 12800,
"eval_interval": 100000,
"time_limit": "12:00:00", # 12 hours for HumanoidBench
"mem": "64G"
}
}
def create_job_script(task, config, seed, phase):
"""Create SLURM script for specific task/seed combination."""
script_content = f'''#!/bin/bash
#SBATCH --job-name=fasttd3_{phase}_{task.replace("-", "_")}_s{seed}
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --time={config["time_limit"]}
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem={config["mem"]}
#SBATCH --output=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.out
#SBATCH --error=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.err
# Load necessary modules
module purge
module load devel/cuda/12.4
module load compiler/intel/2025.1_llvm
# Navigate to the project directory
cd $SLURM_SUBMIT_DIR
# Activate the virtual environment
source .venv/bin/activate
# Set environment variables
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
export JAX_PLATFORMS="cuda"
export WANDB_MODE=online
echo "Starting FastTD3 {phase} training at $(date)"
echo "Task: {task}"
echo "Seed: {seed}"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Node: $(hostname)"
# Run FastTD3 training
python fast_td3/train.py \\
--env_name {task} \\
--exp_name FastTD3_{phase}_{task.replace("-", "_")} \\
--seed {seed} \\
--total_timesteps {config["total_timesteps"]} \\
--num_envs {config["num_envs"]} \\
--batch_size {config["batch_size"]} \\
--buffer_size {config["buffer_size"]} \\
--eval_interval {config["eval_interval"]} \\
--render_interval 0 \\
--project FastTD3_HoReKa_{phase.title()}
echo "Job completed at $(date)"
'''
return script_content
def submit_job(script_path, dry_run=False):
"""Submit SLURM job and return job ID."""
if dry_run:
print(f"[DRY RUN] Would submit: {script_path}")
return "12345" # Fake job ID
try:
result = subprocess.run(['sbatch', script_path],
capture_output=True, text=True, check=True)
job_id = result.stdout.strip().split()[-1]
print(f"✅ Submitted {script_path} -> Job ID: {job_id}")
return job_id
except subprocess.CalledProcessError as e:
print(f"❌ Failed to submit {script_path}: {e.stderr}")
return None
def main():
parser = argparse.ArgumentParser(description='Submit FastTD3 experiment batch')
parser.add_argument('--phase', type=str, choices=['1', '2', '3', 'all'],
default='1', help='Experiment phase to run')
parser.add_argument('--tasks', type=str, default='all',
help='Comma-separated task names or "all"')
parser.add_argument('--seeds', type=int, default=3,
help='Number of random seeds to run')
parser.add_argument('--dry-run', action='store_true',
help='Print commands without executing')
parser.add_argument('--delay', type=int, default=5,
help='Delay between job submissions (seconds)')
args = parser.parse_args()
# Create logs directory
os.makedirs('logs', exist_ok=True)
os.makedirs('scripts', exist_ok=True)
# Determine which phases to run
if args.phase == 'all':
phases = ['phase1_mujoco', 'phase2_isaaclab', 'phase3_humanoidbench']
else:
phase_map = {'1': 'phase1_mujoco', '2': 'phase2_isaaclab', '3': 'phase3_humanoidbench'}
phases = [phase_map[args.phase]]
submitted_jobs = []
for phase in phases:
config = EXPERIMENT_CONFIG[phase]
# Determine tasks to run
if args.tasks == 'all':
tasks = config['tasks']
else:
tasks = [t.strip() for t in args.tasks.split(',')]
# Validate tasks exist in config
invalid = set(tasks) - set(config['tasks'])
if invalid:
print(f"❌ Invalid tasks for {phase}: {invalid}")
continue
print(f"\\n🚀 Starting {phase} with tasks: {tasks}")
print(f" Seeds: {list(range(1, args.seeds + 1))}")
for task in tasks:
for seed in range(1, args.seeds + 1):
# Create job script
script_content = create_job_script(task, config, seed, phase)
script_name = f"scripts/fasttd3_{phase}_{task.replace('-', '_')}_s{seed}.slurm"
with open(script_name, 'w') as f:
f.write(script_content)
# Submit job
job_id = submit_job(script_name, args.dry_run)
if job_id:
submitted_jobs.append({
'job_id': job_id,
'phase': phase,
'task': task,
'seed': seed,
'script': script_name
})
# Delay between submissions to avoid overwhelming scheduler
if not args.dry_run and args.delay > 0:
time.sleep(args.delay)
# Summary
print(f"\\n📊 Submission Summary:")
print(f" Total jobs submitted: {len(submitted_jobs)}")
if submitted_jobs:
# Save job tracking info
tracking_file = f"experiment_tracking_{int(time.time())}.yaml"
with open(tracking_file, 'w') as f:
yaml.dump({
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
'jobs': submitted_jobs
}, f, default_flow_style=False)
print(f" Job tracking saved to: {tracking_file}")
print(f"\\n💡 Monitor progress with:")
print(f" squeue -u $USER")
print(f" python monitor_experiments.py {tracking_file}")
if __name__ == "__main__":
main()