- Add complete HoReKa installation guide without conda dependency - Include SLURM job script with GPU configuration and account setup - Add helper scripts for job submission and environment testing - Integrate wandb logging with both online and offline modes - Support MuJoCo Playground environments for humanoid control - Update README with clear separation of added vs original content
116 lines
3.8 KiB
Python
Executable File
116 lines
3.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Helper script to submit FastTD3 jobs to SLURM with proper wandb setup.
|
|
"""
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
|
|
def check_wandb_setup():
    """Check that wandb is importable and can start (and finish) an offline run.

    A throwaway offline run is created to exercise the install end to end. It is
    written into a private temporary directory that is deleted afterwards, so the
    check no longer leaves a stray ``wandb/offline-run-*`` directory in the
    current working directory next to the user's real offline runs.

    Returns:
        bool: True if the smoke test succeeded, False otherwise. A one-line
        diagnostic is printed in both cases.
    """
    import shutil
    import tempfile

    try:
        import wandb
    except ImportError as e:
        print(f"✗ wandb setup issue: {e}")
        return False

    scratch_dir = tempfile.mkdtemp(prefix="wandb_check_")
    try:
        # Offline mode needs neither network access nor an API key.
        wandb.init(mode="offline", dir=scratch_dir)
        wandb.finish()
    except Exception as e:  # wandb can fail in many ways; report any of them
        print(f"✗ wandb setup issue: {e}")
        return False
    finally:
        # Best-effort cleanup of the dummy run; never mask the check's result.
        shutil.rmtree(scratch_dir, ignore_errors=True)

    print("✓ wandb is properly installed")
    return True
|
|
|
|
def check_environment():
    """Verify the helper is being run from inside the FastTD3 checkout.

    Both the project's ``.venv`` and its training entry point are looked up
    relative to the current working directory, in that order. The first one
    that is missing is reported and the check stops there.

    Returns:
        bool: True when everything is in place, False otherwise.
    """
    required_paths = (
        ('.venv',
         "✗ Virtual environment not found. Run from the FastTD3 directory."),
        ('fast_td3/train.py',
         "✗ FastTD3 training script not found. Run from the FastTD3 directory."),
    )

    for expected, complaint in required_paths:
        if os.path.exists(expected):
            continue
        print(complaint)
        return False

    print("✓ Environment looks good")
    return True
|
|
|
|
def submit_job(script_path="run_fasttd3.slurm", use_wandb_online=False):
    """Submit a FastTD3 SLURM batch script with ``sbatch``.

    When ``use_wandb_online`` is true, the user is prompted for a wandb API key.
    If a key is entered, the job is submitted from a private temporary copy of
    ``script_path``. In that copy, the ``# export WANDB_API_KEY=...`` placeholder
    is filled in and ``export WANDB_MODE=offline`` is commented out.

    The original script is never modified. As a result:
    - the API key is not persisted in a (possibly version-controlled) file;
    - later offline submissions really run offline;
    - a different key can be used next time.

    If no key is entered, the script is submitted unchanged.

    Args:
        script_path: Path to the SLURM batch script.
        use_wandb_online: Whether to offer switching the job to wandb online mode.

    Returns:
        bool: True if ``sbatch`` accepted the job, False otherwise.
        Diagnostics are printed either way.
    """
    if not os.path.exists(script_path):
        print(f"✗ SLURM script {script_path} not found")
        return False

    print(f"Submitting job with script: {script_path}")

    online_copy = None
    if use_wandb_online:
        api_key = input("Enter your wandb API key (or press Enter to skip): ").strip()
        if api_key:
            online_copy = _write_online_script(script_path, api_key)
            if online_copy is None:
                return False
            print("✓ Prepared a temporary copy of the script with the wandb API key "
                  "(original script left unchanged)")

    try:
        return _run_sbatch(online_copy or script_path)
    finally:
        if online_copy is not None:
            # sbatch hands the script contents to SLURM at submission time, so the
            # temporary file holding the API key can be deleted right away.
            os.remove(online_copy)


def _write_online_script(script_path, api_key):
    """Write a temp copy of *script_path* set up for wandb online mode.

    Returns the copy's path, or None (after printing why) if the script lacks
    the API-key placeholder. Without the placeholder the key would silently
    not be set.
    """
    import shlex
    import tempfile

    key_placeholder = "# export WANDB_API_KEY=your_api_key_here"
    with open(script_path, 'r') as f:
        content = f.read()

    if key_placeholder not in content:
        print(f"✗ '{key_placeholder}' not found in {script_path}; "
              "cannot set the wandb API key")
        return None

    # Quote the key so stray characters pasted at the prompt cannot alter the shell script.
    content = content.replace(key_placeholder,
                              f"export WANDB_API_KEY={shlex.quote(api_key)}")
    content = content.replace("export WANDB_MODE=offline",
                              "# export WANDB_MODE=offline # Using online mode")

    # mkstemp creates the file readable/writable by the owner only.
    # NOTE(review): if the script does not set --job-name, SLURM names the job
    # after this temporary file; the stem prefix keeps that name recognisable.
    stem = os.path.splitext(os.path.basename(script_path))[0]
    fd, copy_path = tempfile.mkstemp(prefix=f"{stem}_online_", suffix=".slurm")
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(content)
    except BaseException:
        os.remove(copy_path)
        raise
    return copy_path


def _run_sbatch(path):
    """Run ``sbatch path``, print the outcome and monitoring hints, return success."""
    try:
        result = subprocess.run(['sbatch', path], capture_output=True, text=True)
    except FileNotFoundError:
        print("✗ sbatch command not found. Are you on a SLURM cluster?")
        return False
    except Exception as e:
        print(f"✗ Error submitting job: {e}")
        return False

    if result.returncode != 0:
        print("✗ Job submission failed:")
        print(result.stderr.strip())
        return False

    output = result.stdout.strip()
    print("✓ Job submitted successfully:")
    print(output)
    # sbatch normally prints "Submitted batch job <id>". The job is already
    # accepted at this point, so unexpected output must not turn into a failure.
    words = output.split()
    job_id = words[-1] if words else "<jobid>"
    print("\nTo monitor the job:")
    print(f" squeue -j {job_id}")
    print(f" tail -f fasttd3_{job_id}.out")
    return True
|
|
|
|
def main():
    """Interactive entry point: validate the setup, ask for the wandb mode, submit.

    Terminates the process with exit status 1 if a pre-flight check or the
    submission itself fails.
    """
    print("FastTD3 Job Submission Helper")
    print("=" * 30)

    # Each pre-flight check prints its own diagnostics; stop at the first failure.
    for preflight in (check_environment, check_wandb_setup):
        if not preflight():
            sys.exit(1)

    reply = input("Use wandb online mode? (y/N): ")
    use_online = reply.lower().startswith('y')

    if not submit_job(use_wandb_online=use_online):
        print("\n❌ Job submission failed")
        sys.exit(1)

    hints = [
        "- Check job status: squeue -u $USER",
        "- View output: tail -f fasttd3_<jobid>.out",
        "- Cancel job: scancel <jobid>",
    ]
    if not use_online:
        hints.append("- Job runs in wandb offline mode. Sync later with: wandb sync <run_dir>")

    print("\n🎉 Job submitted successfully!")
    print("\nTips:")
    for hint in hints:
        print(hint)
|
|
|
|
# Run the interactive helper only when executed as a script, not when imported.
if __name__ == "__main__":
    main()