FastTD3/monitor_experiments.py
ys1087@partner.kit.edu e7e3ae48f1 Add FastTD3 HoReKa experiment management system
- Fixed JAX/PyTorch dtype mismatch for successful training
- Added experiment plan with paper-accurate hyperparameters
- Created batch submission and monitoring scripts
- Cleaned up log files and updated gitignore
- Ready for systematic paper replication
2025-07-22 17:08:03 +02:00

168 lines
6.0 KiB
Python
Executable File
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Monitor FastTD3 experiments running on HoReKa cluster.
Usage: python monitor_experiments.py [tracking_file.yaml] [--watch] [--summary]
"""
import os
import subprocess
import argparse
import yaml
import time
import requests
from datetime import datetime, timedelta
from pathlib import Path
def get_job_status():
    """Get current SLURM job status for the invoking user.

    Runs ``squeue`` for the user named by the ``USER`` (or ``LOGNAME``)
    environment variable and parses one job per output line.

    Returns:
        dict: Maps job ID (str) to a dict with the keys ``name``, ``state``
        (compact squeue code as produced by ``%t``), ``time_used`` and
        ``node``. An empty dict is returned when the user cannot be
        determined, ``squeue`` is not installed, or ``squeue`` exits with
        a non-zero status.
    """
    user = os.environ.get('USER') or os.environ.get('LOGNAME')
    if not user:
        return {}

    try:
        # '|' is used as the field separator because node lists such as
        # "hkn[0401,0403]" (and job names) may themselves contain commas.
        result = subprocess.run(
            ['squeue', '-u', user, '--format=%i|%j|%t|%M|%N'],
            capture_output=True, text=True, check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing or non-executable squeue binary.
        return {}

    jobs = {}
    # The first output line is the squeue header.
    for line in result.stdout.strip().split('\n')[1:]:
        line = line.strip()
        if line.count('|') < 4:
            continue
        # Peel the ID off the left and the last three fields off the right,
        # so a '|' inside the job name cannot shift the columns.
        job_id, rest = line.split('|', 1)
        name, state, time_used, node = rest.rsplit('|', 3)
        jobs[job_id] = {
            'name': name,
            'state': state,
            'time_used': time_used,
            'node': node,
        }
    return jobs
def check_wandb_run(project, run_name):
    """Return placeholder status information for a wandb run.

    No wandb API call is made yet; both arguments are currently ignored
    and the same stub payload is returned for every run.

    Args:
        project: wandb project name (unused for now).
        run_name: wandb run name (unused for now).

    Returns:
        dict: ``exists`` (always True), ``last_logged`` (current local time
        as an ISO-8601 string) and ``status`` (always ``'running'``).
    """
    # TODO: replace this stub with a real wandb API lookup.
    timestamp = datetime.now().isoformat()
    return dict(exists=True, last_logged=timestamp, status='running')
def format_duration(seconds):
    """Format a duration in seconds as a short human-readable string.

    Args:
        seconds: Duration in seconds. Fractional values (for example from
            ``timedelta.total_seconds()``) are truncated to whole seconds
            so the output never contains decimals.

    Returns:
        str: ``"<s>s"`` below one minute, ``"<m>m <s>s"`` below one hour,
        otherwise ``"<h>h <m>m"`` (seconds are dropped for long durations).
    """
    total = int(seconds)
    if total < 60:
        return f"{total}s"
    minutes, secs = divmod(total, 60)
    if minutes < 60:
        return f"{minutes}m {secs}s"
    hours, minutes = divmod(minutes, 60)
    return f"{hours}h {minutes}m"
def print_status_table(jobs_status):
    """Print a formatted, colour-coded status table to stdout.

    Args:
        jobs_status: Iterable of dicts. Each may carry ``job_id``, ``task``,
            ``seed`` and a ``slurm_status`` dict with ``state``,
            ``time_used`` and ``node``; missing entries fall back to
            placeholder text.
    """
    green, yellow, blue, red = "\033[92m", "\033[93m", "\033[94m", "\033[91m"
    reset = "\033[0m"
    # Both the long state names and the compact codes emitted by
    # `squeue --format=%t` are recognised.
    state_colors = {
        'RUNNING': green, 'R': green,
        'PENDING': yellow, 'PD': yellow,
        'COMPLETED': blue, 'CD': blue,
        'FAILED': red, 'F': red,
        'CANCELLED': red, 'CA': red,
    }

    print(f"{'Job ID':<10} {'Task':<25} {'Seed':<4} {'State':<12} {'Runtime':<12} {'Node':<10}")
    print("=" * 85)
    for job_info in jobs_status:
        # str() guards against ints/None coming out of a YAML tracking file.
        job_id = str(job_info.get('job_id', 'N/A'))
        task = str(job_info.get('task', 'Unknown'))[:24]
        seed = str(job_info.get('seed', '?'))
        slurm_info = job_info.get('slurm_status') or {}
        state = str(slurm_info.get('state', 'UNKNOWN'))
        time_used = str(slurm_info.get('time_used', 'N/A'))
        node = str(slurm_info.get('node', 'N/A'))

        # Pad first, colour second: the escape sequences are invisible and
        # must not count towards the column width.
        state_cell = f"{state:<12}"
        color = state_colors.get(state)
        if color:
            state_cell = f"{color}{state_cell}{reset}"

        print(f"{job_id:<10} {task:<25} {seed:<4} {state_cell} {time_used:<12} {node:<10}")
def main():
    """Parse CLI arguments and show experiment status once or in a loop.

    Command line:
        tracking_file   Optional YAML file from batch submission; expected
                        to hold a ``jobs`` list whose entries carry
                        ``job_id`` (plus task/seed metadata).
        --watch / -w    Refresh the display every 30 seconds until Ctrl+C.
        --summary / -s  Print only the per-state summary, not the table.

    Without a (readable) tracking file, all of the user's SLURM jobs
    reported by ``get_job_status`` are listed instead.
    """
    parser = argparse.ArgumentParser(description='Monitor FastTD3 experiments')
    parser.add_argument('tracking_file', nargs='?',
                        help='YAML tracking file from batch submission')
    parser.add_argument('--watch', '-w', action='store_true',
                        help='Continuously monitor (refresh every 30s)')
    parser.add_argument('--summary', '-s', action='store_true',
                        help='Show summary statistics only')
    args = parser.parse_args()

    def print_summary(jobs_status):
        """Print the number and percentage of jobs in each state."""
        total_jobs = len(jobs_status)
        states = {}
        for job in jobs_status:
            state = job['slurm_status']['state']
            states[state] = states.get(state, 0) + 1
        print(f"\n📊 Summary ({total_jobs} jobs):")
        # An empty job list yields no states, so the division is safe.
        for state, count in sorted(states.items()):
            percentage = (count / total_jobs) * 100
            print(f" {state}: {count} ({percentage:.1f}%)")

    def update_and_display():
        """Clear the terminal and render the current job status."""
        os.system('clear' if os.name == 'posix' else 'cls')
        print("🔍 FastTD3 HoReKa Experiment Monitor")
        print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print()

        # Get current SLURM status
        slurm_jobs = get_job_status()

        if args.tracking_file and os.path.exists(args.tracking_file):
            # Load tracking file; an empty file parses to None.
            with open(args.tracking_file, 'r') as f:
                tracking_data = yaml.safe_load(f) or {}

            jobs_status = []
            for job in tracking_data.get('jobs') or []:
                job_info = job.copy()
                # YAML loads numeric IDs as int, while squeue IDs are
                # strings; normalise so the lookup can actually match.
                job_id = str(job.get('job_id', 'N/A'))
                job_info['slurm_status'] = slurm_jobs.get(job_id, {'state': 'NOT_FOUND'})
                jobs_status.append(job_info)

            if not args.summary:
                print_status_table(jobs_status)
            print_summary(jobs_status)
        else:
            if args.tracking_file:
                # Make the fallback visible instead of silently ignoring
                # a mistyped path.
                print(f"Warning: tracking file not found: {args.tracking_file}")
            # Show all user jobs if no tracking file
            if slurm_jobs:
                print("📋 All SLURM jobs:")
                jobs_list = [
                    {
                        'job_id': job_id,
                        'task': info['name'],
                        'seed': '?',
                        'slurm_status': info,
                    }
                    for job_id, info in slurm_jobs.items()
                ]
                if args.summary:
                    print_summary(jobs_list)
                else:
                    print_status_table(jobs_list)
            else:
                print("✅ No active jobs found")

        print("\n💡 Commands:")
        print(" squeue -u $USER # Detailed SLURM status")
        print(" scancel <job_id> # Cancel specific job")
        print(" tail -f logs/*.out # Follow job logs")
        print(" python collect_results.py # Gather completed results")

    if args.watch:
        try:
            while True:
                update_and_display()
                print("\n⏱ Refreshing in 30s... (Ctrl+C to stop)")
                time.sleep(30)
        except KeyboardInterrupt:
            print("\n👋 Monitoring stopped")
    else:
        update_and_display()
# Script entry point: run the monitor only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    raise SystemExit(main())