Add FastTD3 HoReKa experiment management system
- Fixed JAX/PyTorch dtype mismatch for successful training - Added experiment plan with paper-accurate hyperparameters - Created batch submission and monitoring scripts - Cleaned up log files and updated gitignore - Ready for systematic paper replication
This commit is contained in:
parent
15750f56b2
commit
e7e3ae48f1
22
.gitignore
vendored
22
.gitignore
vendored
@ -1,5 +1,23 @@
|
||||
models
|
||||
wandb
|
||||
# Model checkpoints and training outputs
|
||||
models/
|
||||
wandb/
|
||||
|
||||
# Log files and job outputs
|
||||
logs/
|
||||
*.out
|
||||
*.err
|
||||
fasttd3_*.out
|
||||
fasttd3_*.err
|
||||
|
||||
# Experiment tracking
|
||||
experiment_tracking_*.yaml
|
||||
|
||||
# Python
|
||||
*.pyc
|
||||
__pycache__/
|
||||
.ipynb_checkpoints
|
||||
fast_td3.egg-info/
|
||||
.venv
|
||||
|
||||
# SLURM scripts (generated)
|
||||
scripts/
|
||||
|
11
README.md
11
README.md
@ -88,25 +88,26 @@ The setup includes:
|
||||
- **Test script** (`test_setup.py`) for environment verification
|
||||
- **MuJoCo Playground environment** (`T1JoystickFlatTerrain`) for humanoid control
|
||||
- **Automatic GPU detection** and CUDA 12.4 compatibility
|
||||
- **Wandb logging** with offline mode support
|
||||
- **Wandb logging** with online mode by default
|
||||
|
||||
### Wandb Integration
|
||||
|
||||
The scripts support both online and offline wandb logging:
|
||||
|
||||
**Online mode:**
|
||||
**Online mode (default):**
|
||||
```bash
|
||||
export WANDB_API_KEY=your_api_key_here
|
||||
python submit_job.py
|
||||
# Select 'y' when prompted for online mode
|
||||
# Select 'y' when prompted for online mode (default)
|
||||
```
|
||||
|
||||
**Offline mode (default):**
|
||||
**Offline mode:**
|
||||
```bash
|
||||
# Jobs run in offline mode by default
|
||||
# Select 'n' when prompted for online mode
|
||||
# Sync later with: wandb sync <run_directory>
|
||||
```
|
||||
|
||||
|
||||
---
|
||||
|
||||
# ORIGINAL README:
|
||||
|
67
experiment_plan.md
Normal file
67
experiment_plan.md
Normal file
@ -0,0 +1,67 @@
|
||||
# FastTD3 HoReKa Experiment Plan
|
||||
*Added by Dominik - Paper Replication Study*
|
||||
|
||||
## ✅ Proof of Concept Results
|
||||
**Initial Success**: [HoReKa Dev Run](https://wandb.ai/rl-network-scaling/FastTD3_HoReKa_Dev?nw=nwuserdominik_roth)
|
||||
|
||||
- **Task**: T1JoystickFlatTerrain
|
||||
- **Duration**: 7 minutes (5000 timesteps)
|
||||
- **Performance**: Successfully training at ~29 it/s
|
||||
- **Key Achievement**: Fixed JAX/PyTorch dtype mismatch issue (removed JAX_ENABLE_X64)
|
||||
- **Status**: ✅ Environment working, ready for full-scale experiments
|
||||
|
||||
## Experiments to Replicate
|
||||
|
||||
### Phase 1: MuJoCo Playground (Figure 11 from paper)
|
||||
- `T1JoystickFlatTerrain` (3600s)
|
||||
- `T1JoystickRoughTerrain` (3600s)
|
||||
- `G1JoystickFlatTerrain` (3600s)
|
||||
- `G1JoystickRoughTerrain` (3600s)
|
||||
|
||||
**Hyperparameters (from paper):**
|
||||
- total_timesteps: 500000
|
||||
- num_envs: 2048
|
||||
- batch_size: 32768
|
||||
- buffer_size: 102400 (50K per env)
|
||||
- eval_interval: 25000
|
||||
|
||||
### Phase 2: IsaacLab (Figure 10 from paper)
|
||||
- `Isaac-Velocity-Flat-G1-v0` (3600s)
|
||||
- `Isaac-Velocity-Rough-G1-v0` (3600s)
|
||||
- `Isaac-Repose-Cube-Allegro-Direct-v0` (3600s)
|
||||
- `Isaac-Repose-Cube-Shadow-Direct-v0` (3600s)
|
||||
- `Isaac-Velocity-Flat-H1-v0` (3600s)
|
||||
- `Isaac-Velocity-Rough-H1-v0` (3600s)
|
||||
|
||||
**Hyperparameters:**
|
||||
- total_timesteps: 1000000
|
||||
- num_envs: 1024
|
||||
- batch_size: 32768
|
||||
- buffer_size: 51200
|
||||
- eval_interval: 50000
|
||||
|
||||
### Phase 3: HumanoidBench (Figure 9 from paper - subset)
|
||||
- `h1hand-walk` (10800s)
|
||||
- `h1hand-run` (10800s)
|
||||
- `h1hand-hurdle` (10800s)
|
||||
- `h1hand-stair` (10800s)
|
||||
- `h1hand-slide` (10800s)
|
||||
|
||||
**Hyperparameters:**
|
||||
- total_timesteps: 2000000
|
||||
- num_envs: 256
|
||||
- batch_size: 16384
|
||||
- buffer_size: 12800
|
||||
- eval_interval: 100000
|
||||
|
||||
## Usage
|
||||
|
||||
Submit Phase 1:
|
||||
```bash
|
||||
python submit_experiment_batch.py --phase 1 --seeds 3
|
||||
```
|
||||
|
||||
Monitor progress:
|
||||
```bash
|
||||
python monitor_experiments.py --watch
|
||||
```
|
168
monitor_experiments.py
Executable file
168
monitor_experiments.py
Executable file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Monitor FastTD3 experiments running on HoReKa cluster.
|
||||
Usage: python monitor_experiments.py [tracking_file.yaml]
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
import yaml
|
||||
import time
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
def get_job_status():
|
||||
"""Get current SLURM job status for user."""
|
||||
try:
|
||||
result = subprocess.run(['squeue', '-u', os.environ['USER'], '--format=%i,%j,%t,%M,%N'],
|
||||
capture_output=True, text=True, check=True)
|
||||
|
||||
jobs = {}
|
||||
lines = result.stdout.strip().split('\n')
|
||||
if len(lines) > 1: # Skip header
|
||||
for line in lines[1:]:
|
||||
parts = line.split(',')
|
||||
if len(parts) >= 5:
|
||||
job_id, name, state, time_used, node = parts
|
||||
jobs[job_id] = {
|
||||
'name': name,
|
||||
'state': state,
|
||||
'time_used': time_used,
|
||||
'node': node
|
||||
}
|
||||
return jobs
|
||||
except subprocess.CalledProcessError:
|
||||
return {}
|
||||
|
||||
def check_wandb_run(project, run_name):
|
||||
"""Check if wandb run exists and get basic stats."""
|
||||
# Note: This requires wandb API access
|
||||
# For now, just return placeholder
|
||||
return {
|
||||
'exists': True,
|
||||
'last_logged': datetime.now().isoformat(),
|
||||
'status': 'running'
|
||||
}
|
||||
|
||||
def format_duration(seconds):
|
||||
"""Format duration in human readable format."""
|
||||
if seconds < 60:
|
||||
return f"{seconds}s"
|
||||
elif seconds < 3600:
|
||||
return f"{seconds//60}m {seconds%60}s"
|
||||
else:
|
||||
return f"{seconds//3600}h {(seconds%3600)//60}m"
|
||||
|
||||
def print_status_table(jobs_status):
|
||||
"""Print formatted status table."""
|
||||
print(f"{'Job ID':<10} {'Task':<25} {'Seed':<4} {'State':<12} {'Runtime':<12} {'Node':<10}")
|
||||
print("=" * 85)
|
||||
|
||||
for job_info in jobs_status:
|
||||
job_id = job_info.get('job_id', 'N/A')
|
||||
task = job_info.get('task', 'Unknown')[:24]
|
||||
seed = str(job_info.get('seed', '?'))
|
||||
|
||||
slurm_info = job_info.get('slurm_status', {})
|
||||
state = slurm_info.get('state', 'UNKNOWN')
|
||||
time_used = slurm_info.get('time_used', 'N/A')
|
||||
node = slurm_info.get('node', 'N/A')
|
||||
|
||||
# Color coding for states
|
||||
if state == 'RUNNING':
|
||||
state_color = f"\\033[92m{state}\\033[0m" # Green
|
||||
elif state == 'PENDING':
|
||||
state_color = f"\\033[93m{state}\\033[0m" # Yellow
|
||||
elif state == 'COMPLETED':
|
||||
state_color = f"\\033[94m{state}\\033[0m" # Blue
|
||||
elif state in ['FAILED', 'CANCELLED']:
|
||||
state_color = f"\\033[91m{state}\\033[0m" # Red
|
||||
else:
|
||||
state_color = state
|
||||
|
||||
print(f"{job_id:<10} {task:<25} {seed:<4} {state_color:<20} {time_used:<12} {node:<10}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Monitor FastTD3 experiments')
|
||||
parser.add_argument('tracking_file', nargs='?',
|
||||
help='YAML tracking file from batch submission')
|
||||
parser.add_argument('--watch', '-w', action='store_true',
|
||||
help='Continuously monitor (refresh every 30s)')
|
||||
parser.add_argument('--summary', '-s', action='store_true',
|
||||
help='Show summary statistics only')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def update_and_display():
|
||||
os.system('clear' if os.name == 'posix' else 'cls')
|
||||
|
||||
print(f"🔍 FastTD3 HoReKa Experiment Monitor")
|
||||
print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print()
|
||||
|
||||
# Get current SLURM status
|
||||
slurm_jobs = get_job_status()
|
||||
|
||||
if args.tracking_file and os.path.exists(args.tracking_file):
|
||||
# Load tracking file
|
||||
with open(args.tracking_file, 'r') as f:
|
||||
tracking_data = yaml.safe_load(f)
|
||||
|
||||
jobs_status = []
|
||||
for job in tracking_data['jobs']:
|
||||
job_id = job['job_id']
|
||||
job_info = job.copy()
|
||||
job_info['slurm_status'] = slurm_jobs.get(job_id, {'state': 'NOT_FOUND'})
|
||||
jobs_status.append(job_info)
|
||||
|
||||
print_status_table(jobs_status)
|
||||
|
||||
# Summary statistics
|
||||
total_jobs = len(jobs_status)
|
||||
states = {}
|
||||
for job in jobs_status:
|
||||
state = job['slurm_status']['state']
|
||||
states[state] = states.get(state, 0) + 1
|
||||
|
||||
print(f"\\n📊 Summary ({total_jobs} jobs):")
|
||||
for state, count in sorted(states.items()):
|
||||
percentage = (count / total_jobs) * 100
|
||||
print(f" {state}: {count} ({percentage:.1f}%)")
|
||||
|
||||
else:
|
||||
# Show all user jobs if no tracking file
|
||||
if slurm_jobs:
|
||||
print("📋 All SLURM jobs:")
|
||||
jobs_list = []
|
||||
for job_id, info in slurm_jobs.items():
|
||||
jobs_list.append({
|
||||
'job_id': job_id,
|
||||
'task': info['name'],
|
||||
'seed': '?',
|
||||
'slurm_status': info
|
||||
})
|
||||
print_status_table(jobs_list)
|
||||
else:
|
||||
print("✅ No active jobs found")
|
||||
|
||||
print(f"\\n💡 Commands:")
|
||||
print(f" squeue -u $USER # Detailed SLURM status")
|
||||
print(f" scancel <job_id> # Cancel specific job")
|
||||
print(f" tail -f logs/*.out # Follow job logs")
|
||||
print(f" python collect_results.py # Gather completed results")
|
||||
|
||||
if args.watch:
|
||||
try:
|
||||
while True:
|
||||
update_and_display()
|
||||
print(f"\\n⏱️ Refreshing in 30s... (Ctrl+C to stop)")
|
||||
time.sleep(30)
|
||||
except KeyboardInterrupt:
|
||||
print("\\n👋 Monitoring stopped")
|
||||
else:
|
||||
update_and_display()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -24,12 +24,11 @@ source .venv/bin/activate
|
||||
# Set environment variables for proper GPU usage
|
||||
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
|
||||
export JAX_PLATFORMS="cuda"
|
||||
export JAX_ENABLE_X64=True
|
||||
|
||||
# Ensure wandb is logged in (set WANDB_API_KEY environment variable)
|
||||
# export WANDB_API_KEY=your_api_key_here
|
||||
# For testing, use offline mode
|
||||
export WANDB_MODE=offline
|
||||
# Use online mode by default - set WANDB_API_KEY before running
|
||||
export WANDB_MODE=online
|
||||
|
||||
# Run FastTD3 training with MuJoCo Playground environment
|
||||
python fast_td3/train.py \
|
||||
@ -41,6 +40,6 @@ python fast_td3/train.py \
|
||||
--batch_size 4096 \
|
||||
--eval_interval 5000 \
|
||||
--render_interval 0 \
|
||||
--project FastTD3_HoReKa
|
||||
--project FastTD3_HoReKa \
|
||||
|
||||
echo "Job completed at $(date)"
|
@ -24,10 +24,9 @@ source .venv/bin/activate
|
||||
# Set environment variables for proper GPU usage
|
||||
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
|
||||
export JAX_PLATFORMS="cuda"
|
||||
export JAX_ENABLE_X64=True
|
||||
|
||||
# For testing, use offline mode
|
||||
export WANDB_MODE=offline
|
||||
# Use online mode by default - set WANDB_API_KEY before running
|
||||
export WANDB_MODE=online
|
||||
|
||||
echo "Starting FastTD3 dev test at $(date)"
|
||||
echo "GPU: $CUDA_VISIBLE_DEVICES"
|
||||
@ -43,6 +42,6 @@ python fast_td3/train.py \
|
||||
--batch_size 1024 \
|
||||
--eval_interval 2500 \
|
||||
--render_interval 0 \
|
||||
--project FastTD3_HoReKa_Dev
|
||||
--project FastTD3_HoReKa_Dev \
|
||||
|
||||
echo "Job completed at $(date)"
|
51
run_fasttd3_full.slurm
Normal file
51
run_fasttd3_full.slurm
Normal file
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=fasttd3_full
|
||||
#SBATCH --account=hk-project-p0022232
|
||||
#SBATCH --partition=accelerated
|
||||
#SBATCH --time=04:00:00
|
||||
#SBATCH --gres=gpu:1
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --cpus-per-task=8
|
||||
#SBATCH --mem=64G
|
||||
#SBATCH --output=fasttd3_full_%j.out
|
||||
#SBATCH --error=fasttd3_full_%j.err
|
||||
|
||||
# Load necessary modules
|
||||
module purge
|
||||
module load devel/cuda/12.4
|
||||
module load compiler/intel/2025.1_llvm
|
||||
|
||||
# Navigate to the project directory
|
||||
cd $SLURM_SUBMIT_DIR
|
||||
|
||||
# Activate the virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Set environment variables for proper GPU usage
|
||||
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
|
||||
export JAX_PLATFORMS="cuda"
|
||||
|
||||
# Use online mode by default - set WANDB_API_KEY before running
|
||||
export WANDB_MODE=online
|
||||
|
||||
echo "Starting FastTD3 full training at $(date)"
|
||||
echo "GPU: $CUDA_VISIBLE_DEVICES"
|
||||
echo "Node: $(hostname)"
|
||||
echo "Task: ${TASK_NAME:-T1JoystickFlatTerrain}"
|
||||
echo "Seed: ${SEED:-42}"
|
||||
echo "Timesteps: ${TOTAL_TIMESTEPS:-500000}"
|
||||
|
||||
# Run FastTD3 training with paper-based hyperparameters
|
||||
python fast_td3/train.py \
|
||||
--env_name ${TASK_NAME:-T1JoystickFlatTerrain} \
|
||||
--exp_name FastTD3_Full_${TASK_NAME:-T1JoystickFlatTerrain} \
|
||||
--seed ${SEED:-42} \
|
||||
--total_timesteps ${TOTAL_TIMESTEPS:-500000} \
|
||||
--num_envs ${NUM_ENVS:-2048} \
|
||||
--batch_size ${BATCH_SIZE:-32768} \
|
||||
--buffer_size ${BUFFER_SIZE:-102400} \
|
||||
--eval_interval ${EVAL_INTERVAL:-25000} \
|
||||
--render_interval 0 \
|
||||
--project FastTD3_HoReKa_Full
|
||||
|
||||
echo "Job completed at $(date)"
|
225
submit_experiment_batch.py
Executable file
225
submit_experiment_batch.py
Executable file
@ -0,0 +1,225 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Batch experiment submission script for FastTD3 paper replication.
|
||||
Usage: python submit_experiment_batch.py --phase 1 --tasks all --seeds 3
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
import time
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
# Experiment configuration based on paper
|
||||
EXPERIMENT_CONFIG = {
|
||||
"phase1_mujoco": {
|
||||
"tasks": [
|
||||
"T1JoystickFlatTerrain",
|
||||
"T1JoystickRoughTerrain",
|
||||
"G1JoystickFlatTerrain",
|
||||
"G1JoystickRoughTerrain"
|
||||
],
|
||||
"total_timesteps": 500000,
|
||||
"num_envs": 2048,
|
||||
"batch_size": 32768,
|
||||
"buffer_size": 102400, # 50K per env for 2048 envs
|
||||
"eval_interval": 25000,
|
||||
"time_limit": "04:00:00",
|
||||
"mem": "64G"
|
||||
},
|
||||
"phase2_isaaclab": {
|
||||
"tasks": [
|
||||
"Isaac-Velocity-Flat-G1-v0",
|
||||
"Isaac-Velocity-Rough-G1-v0",
|
||||
"Isaac-Repose-Cube-Allegro-Direct-v0",
|
||||
"Isaac-Repose-Cube-Shadow-Direct-v0",
|
||||
"Isaac-Velocity-Flat-H1-v0",
|
||||
"Isaac-Velocity-Rough-H1-v0"
|
||||
],
|
||||
"total_timesteps": 1000000,
|
||||
"num_envs": 1024,
|
||||
"batch_size": 32768,
|
||||
"buffer_size": 51200,
|
||||
"eval_interval": 50000,
|
||||
"time_limit": "04:00:00",
|
||||
"mem": "64G"
|
||||
},
|
||||
"phase3_humanoidbench": {
|
||||
"tasks": [
|
||||
"h1hand-walk",
|
||||
"h1hand-run",
|
||||
"h1hand-hurdle",
|
||||
"h1hand-stair",
|
||||
"h1hand-slide"
|
||||
],
|
||||
"total_timesteps": 2000000,
|
||||
"num_envs": 256,
|
||||
"batch_size": 16384,
|
||||
"buffer_size": 12800,
|
||||
"eval_interval": 100000,
|
||||
"time_limit": "12:00:00", # 12 hours for HumanoidBench
|
||||
"mem": "64G"
|
||||
}
|
||||
}
|
||||
|
||||
def create_job_script(task, config, seed, phase):
|
||||
"""Create SLURM script for specific task/seed combination."""
|
||||
|
||||
script_content = f'''#!/bin/bash
|
||||
#SBATCH --job-name=fasttd3_{phase}_{task.replace("-", "_")}_s{seed}
|
||||
#SBATCH --account=hk-project-p0022232
|
||||
#SBATCH --partition=accelerated
|
||||
#SBATCH --time={config["time_limit"]}
|
||||
#SBATCH --gres=gpu:1
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --cpus-per-task=8
|
||||
#SBATCH --mem={config["mem"]}
|
||||
#SBATCH --output=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.out
|
||||
#SBATCH --error=logs/fasttd3_{phase}_{task.replace("-", "_")}_s{seed}_%j.err
|
||||
|
||||
# Load necessary modules
|
||||
module purge
|
||||
module load devel/cuda/12.4
|
||||
module load compiler/intel/2025.1_llvm
|
||||
|
||||
# Navigate to the project directory
|
||||
cd $SLURM_SUBMIT_DIR
|
||||
|
||||
# Activate the virtual environment
|
||||
source .venv/bin/activate
|
||||
|
||||
# Set environment variables
|
||||
export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
|
||||
export JAX_PLATFORMS="cuda"
|
||||
export WANDB_MODE=online
|
||||
|
||||
echo "Starting FastTD3 {phase} training at $(date)"
|
||||
echo "Task: {task}"
|
||||
echo "Seed: {seed}"
|
||||
echo "GPU: $CUDA_VISIBLE_DEVICES"
|
||||
echo "Node: $(hostname)"
|
||||
|
||||
# Run FastTD3 training
|
||||
python fast_td3/train.py \\
|
||||
--env_name {task} \\
|
||||
--exp_name FastTD3_{phase}_{task.replace("-", "_")} \\
|
||||
--seed {seed} \\
|
||||
--total_timesteps {config["total_timesteps"]} \\
|
||||
--num_envs {config["num_envs"]} \\
|
||||
--batch_size {config["batch_size"]} \\
|
||||
--buffer_size {config["buffer_size"]} \\
|
||||
--eval_interval {config["eval_interval"]} \\
|
||||
--render_interval 0 \\
|
||||
--project FastTD3_HoReKa_{phase.title()}
|
||||
|
||||
echo "Job completed at $(date)"
|
||||
'''
|
||||
|
||||
return script_content
|
||||
|
||||
def submit_job(script_path, dry_run=False):
|
||||
"""Submit SLURM job and return job ID."""
|
||||
if dry_run:
|
||||
print(f"[DRY RUN] Would submit: {script_path}")
|
||||
return "12345" # Fake job ID
|
||||
|
||||
try:
|
||||
result = subprocess.run(['sbatch', script_path],
|
||||
capture_output=True, text=True, check=True)
|
||||
job_id = result.stdout.strip().split()[-1]
|
||||
print(f"✅ Submitted {script_path} -> Job ID: {job_id}")
|
||||
return job_id
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"❌ Failed to submit {script_path}: {e.stderr}")
|
||||
return None
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Submit FastTD3 experiment batch')
|
||||
parser.add_argument('--phase', type=str, choices=['1', '2', '3', 'all'],
|
||||
default='1', help='Experiment phase to run')
|
||||
parser.add_argument('--tasks', type=str, default='all',
|
||||
help='Comma-separated task names or "all"')
|
||||
parser.add_argument('--seeds', type=int, default=3,
|
||||
help='Number of random seeds to run')
|
||||
parser.add_argument('--dry-run', action='store_true',
|
||||
help='Print commands without executing')
|
||||
parser.add_argument('--delay', type=int, default=5,
|
||||
help='Delay between job submissions (seconds)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create logs directory
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
os.makedirs('scripts', exist_ok=True)
|
||||
|
||||
# Determine which phases to run
|
||||
if args.phase == 'all':
|
||||
phases = ['phase1_mujoco', 'phase2_isaaclab', 'phase3_humanoidbench']
|
||||
else:
|
||||
phase_map = {'1': 'phase1_mujoco', '2': 'phase2_isaaclab', '3': 'phase3_humanoidbench'}
|
||||
phases = [phase_map[args.phase]]
|
||||
|
||||
submitted_jobs = []
|
||||
|
||||
for phase in phases:
|
||||
config = EXPERIMENT_CONFIG[phase]
|
||||
|
||||
# Determine tasks to run
|
||||
if args.tasks == 'all':
|
||||
tasks = config['tasks']
|
||||
else:
|
||||
tasks = [t.strip() for t in args.tasks.split(',')]
|
||||
# Validate tasks exist in config
|
||||
invalid = set(tasks) - set(config['tasks'])
|
||||
if invalid:
|
||||
print(f"❌ Invalid tasks for {phase}: {invalid}")
|
||||
continue
|
||||
|
||||
print(f"\\n🚀 Starting {phase} with tasks: {tasks}")
|
||||
print(f" Seeds: {list(range(1, args.seeds + 1))}")
|
||||
|
||||
for task in tasks:
|
||||
for seed in range(1, args.seeds + 1):
|
||||
# Create job script
|
||||
script_content = create_job_script(task, config, seed, phase)
|
||||
script_name = f"scripts/fasttd3_{phase}_{task.replace('-', '_')}_s{seed}.slurm"
|
||||
|
||||
with open(script_name, 'w') as f:
|
||||
f.write(script_content)
|
||||
|
||||
# Submit job
|
||||
job_id = submit_job(script_name, args.dry_run)
|
||||
if job_id:
|
||||
submitted_jobs.append({
|
||||
'job_id': job_id,
|
||||
'phase': phase,
|
||||
'task': task,
|
||||
'seed': seed,
|
||||
'script': script_name
|
||||
})
|
||||
|
||||
# Delay between submissions to avoid overwhelming scheduler
|
||||
if not args.dry_run and args.delay > 0:
|
||||
time.sleep(args.delay)
|
||||
|
||||
# Summary
|
||||
print(f"\\n📊 Submission Summary:")
|
||||
print(f" Total jobs submitted: {len(submitted_jobs)}")
|
||||
|
||||
if submitted_jobs:
|
||||
# Save job tracking info
|
||||
tracking_file = f"experiment_tracking_{int(time.time())}.yaml"
|
||||
with open(tracking_file, 'w') as f:
|
||||
yaml.dump({
|
||||
'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
'jobs': submitted_jobs
|
||||
}, f, default_flow_style=False)
|
||||
|
||||
print(f" Job tracking saved to: {tracking_file}")
|
||||
print(f"\\n💡 Monitor progress with:")
|
||||
print(f" squeue -u $USER")
|
||||
print(f" python monitor_experiments.py {tracking_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user