- Fixed JAX/PyTorch dtype mismatch for successful training - Added experiment plan with paper-accurate hyperparameters - Created batch submission and monitoring scripts - Cleaned up log files and updated gitignore - Ready for systematic paper replication
168 lines
6.0 KiB
Python
Executable File
168 lines
6.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
Monitor FastTD3 experiments running on HoReKa cluster.
|
||
Usage: python monitor_experiments.py [tracking_file.yaml]
|
||
"""
|
||
|
||
import os
|
||
import subprocess
|
||
import argparse
|
||
import yaml
|
||
import time
|
||
import requests
|
||
from datetime import datetime, timedelta
|
||
from pathlib import Path
|
||
|
||
def get_job_status():
    """Return a mapping of SLURM job id -> info dict for the current user.

    Runs ``squeue`` with a comma-separated ``--format`` string and parses
    its output. Each value dict carries the keys ``name``, ``state``,
    ``time_used`` and ``node``.

    Returns an empty dict when squeue fails, is not installed (not on a
    SLURM host), or the ``USER`` environment variable is unset.
    """
    try:
        result = subprocess.run(
            ['squeue', '-u', os.environ['USER'], '--format=%i,%j,%t,%M,%N'],
            capture_output=True, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError, KeyError):
        # squeue errored, binary missing, or $USER not set — report no jobs.
        return {}

    jobs = {}
    lines = result.stdout.strip().split('\n')
    for line in lines[1:]:  # first line is the squeue column header
        parts = line.split(',')
        if len(parts) < 5:
            continue
        # The job name (%j) may itself contain commas; the three trailing
        # fields (state, time, node) never do, so unpack from the right
        # and rejoin whatever is left as the name.
        job_id = parts[0]
        state, time_used, node = parts[-3:]
        name = ','.join(parts[1:-3])
        jobs[job_id] = {
            'name': name,
            'state': state,
            'time_used': time_used,
            'node': node,
        }
    return jobs
|
||
|
||
def check_wandb_run(project, run_name):
    """Check whether a wandb run exists and report basic stats.

    Placeholder implementation: wandb API access is not wired up yet, so
    this always reports an existing, running run stamped with the
    current local time. The ``project`` and ``run_name`` arguments are
    accepted but not yet consulted.
    """
    # TODO: query the wandb API for (project, run_name) instead of
    # returning canned data.
    placeholder = {
        'exists': True,
        'last_logged': datetime.now().isoformat(),
        'status': 'running',
    }
    return placeholder
|
||
|
||
def format_duration(seconds):
    """Render a duration in seconds as a compact human-readable string.

    Examples: 45 -> "45s", 125 -> "2m 5s", 3700 -> "1h 1m".
    """
    # Guard clauses, largest unit first.
    if seconds >= 3600:
        return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
    if seconds >= 60:
        return f"{seconds // 60}m {seconds % 60}s"
    return f"{seconds}s"
|
||
|
||
def print_status_table(jobs_status):
    """Print a formatted, color-coded status table for the given jobs.

    Parameters
    ----------
    jobs_status : list of dict
        Each dict may carry 'job_id', 'task', 'seed' and a nested
        'slurm_status' dict with 'state', 'time_used' and 'node'.
        Missing keys fall back to placeholder values.
    """
    print(f"{'Job ID':<10} {'Task':<25} {'Seed':<4} {'State':<12} {'Runtime':<12} {'Node':<10}")
    print("=" * 85)

    # ANSI SGR color code per job state.
    state_colors = {
        'RUNNING': '92',    # green
        'PENDING': '93',    # yellow
        'COMPLETED': '94',  # blue
        'FAILED': '91',     # red
        'CANCELLED': '91',  # red
    }

    for job_info in jobs_status:
        job_id = job_info.get('job_id', 'N/A')
        task = job_info.get('task', 'Unknown')[:24]  # keep within column
        seed = str(job_info.get('seed', '?'))

        slurm_info = job_info.get('slurm_status', {})
        state = slurm_info.get('state', 'UNKNOWN')
        time_used = slurm_info.get('time_used', 'N/A')
        node = slurm_info.get('node', 'N/A')

        code = state_colors.get(code := None) if False else state_colors.get(state)
        if code:
            # Real escape sequences: the previous version wrote "\\033",
            # which printed a literal backslash instead of coloring.
            # Pad to 21 chars so the 9 invisible escape characters still
            # leave a 12-column visible field, keeping columns aligned.
            colored = f"\033[{code}m{state}\033[0m"
            state_field = f"{colored:<21}"
        else:
            state_field = f"{state:<12}"

        print(f"{job_id:<10} {task:<25} {seed:<4} {state_field} {time_used:<12} {node:<10}")
|
||
|
||
def main():
    """Entry point: parse CLI arguments and display experiment status.

    With ``--watch``, clears and redraws the status screen every 30
    seconds until interrupted; otherwise renders a single snapshot.
    """
    parser = argparse.ArgumentParser(description='Monitor FastTD3 experiments')
    parser.add_argument('tracking_file', nargs='?',
                        help='YAML tracking file from batch submission')
    parser.add_argument('--watch', '-w', action='store_true',
                        help='Continuously monitor (refresh every 30s)')
    # NOTE(review): --summary is parsed but never consulted below — either
    # wire it up or drop it. Kept so existing invocations keep working.
    parser.add_argument('--summary', '-s', action='store_true',
                        help='Show summary statistics only')

    args = parser.parse_args()

    def update_and_display():
        """Clear the terminal and render one status snapshot."""
        os.system('clear' if os.name == 'posix' else 'cls')

        print("🔍 FastTD3 HoReKa Experiment Monitor")
        print(f"📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print()

        # Live SLURM queue state for this user.
        slurm_jobs = get_job_status()

        if args.tracking_file and os.path.exists(args.tracking_file):
            # Merge the tracking file's job records with live SLURM state.
            with open(args.tracking_file, 'r') as f:
                tracking_data = yaml.safe_load(f)

            jobs_status = []
            for job in tracking_data['jobs']:
                job_info = job.copy()
                # Jobs missing from squeue output have finished or been
                # purged; mark them explicitly rather than guessing.
                job_info['slurm_status'] = slurm_jobs.get(
                    job['job_id'], {'state': 'NOT_FOUND'})
                jobs_status.append(job_info)

            print_status_table(jobs_status)

            # Summary: count of jobs per SLURM state.
            total_jobs = len(jobs_status)
            states = {}
            for job in jobs_status:
                state = job['slurm_status']['state']
                states[state] = states.get(state, 0) + 1

            # Fixed: these strings previously used "\\n", which printed a
            # literal backslash-n instead of a newline.
            print(f"\n📊 Summary ({total_jobs} jobs):")
            for state, count in sorted(states.items()):
                percentage = (count / total_jobs) * 100
                print(f"   {state}: {count} ({percentage:.1f}%)")

        else:
            # No tracking file given (or it doesn't exist): show every
            # SLURM job for this user instead.
            if slurm_jobs:
                print("📋 All SLURM jobs:")
                jobs_list = []
                for job_id, info in slurm_jobs.items():
                    jobs_list.append({
                        'job_id': job_id,
                        'task': info['name'],
                        'seed': '?',
                        'slurm_status': info,
                    })
                print_status_table(jobs_list)
            else:
                print("✅ No active jobs found")

        print("\n💡 Commands:")
        print("   squeue -u $USER            # Detailed SLURM status")
        print("   scancel <job_id>           # Cancel specific job")
        print("   tail -f logs/*.out         # Follow job logs")
        print("   python collect_results.py  # Gather completed results")

    if args.watch:
        try:
            while True:
                update_and_display()
                print("\n⏱️  Refreshing in 30s... (Ctrl+C to stop)")
                time.sleep(30)
        except KeyboardInterrupt:
            print("\n👋 Monitoring stopped")
    else:
        update_and_display()
|
||
|
||
# Script entry point: run the monitor only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()