#!/usr/bin/env python3
"""
Helper script to submit FastTD3 jobs to SLURM with proper wandb setup.
"""

import os
import re
import shlex
import subprocess
import sys

# Matches an existing (commented-out or active) WANDB_API_KEY export line.
_API_KEY_LINE = re.compile(
    r"^([ \t]*)#?[ \t]*export WANDB_API_KEY=.*$", re.MULTILINE
)
# Matches only an *active* offline-mode export, so an already commented-out
# line is left alone and repeated runs stay idempotent.
_OFFLINE_LINE = re.compile(r"^([ \t]*)export WANDB_MODE=offline", re.MULTILINE)
# Standard sbatch confirmation, e.g. "Submitted batch job 12345 [on cluster x]".
_JOB_ID = re.compile(r"Submitted batch job (\d+)")


def _ask(prompt):
    """Return the user's answer to *prompt*, or "" when stdin is closed."""
    try:
        return input(prompt)
    except EOFError:
        # Non-interactive invocation: behave as if the user pressed Enter.
        print()
        return ""


def parse_job_id(sbatch_output):
    """Extract the job id from the text printed by ``sbatch``.

    Returns the id as a string, or ``None`` when the output is empty.
    Output that does not follow the usual "Submitted batch job N" form
    falls back to its last whitespace-separated token.
    """
    match = _JOB_ID.search(sbatch_output)
    if match:
        return match.group(1)
    tokens = sbatch_output.split()
    return tokens[-1] if tokens else None


def enable_wandb_online(content, api_key):
    """Rewrite SLURM script text so the job runs wandb in online mode.

    Every ``export WANDB_API_KEY=...`` line (commented out or not) is replaced
    by an active export of *api_key*, shell-quoted. Active
    ``export WANDB_MODE=offline`` lines are commented out.

    Returns ``(new_content, key_was_set)``; ``key_was_set`` is ``False`` when
    the script had no WANDB_API_KEY line to replace.
    """
    export_line = f"export WANDB_API_KEY={shlex.quote(api_key)}"
    # A callable replacement keeps backslashes in the key from being
    # interpreted as regex group references.
    content, replaced = _API_KEY_LINE.subn(
        lambda m: m.group(1) + export_line, content
    )
    content = _OFFLINE_LINE.sub(
        r"\1# export WANDB_MODE=offline # Using online mode", content
    )
    return content, replaced > 0


def check_wandb_setup():
    """Check if wandb is properly configured.

    Imports wandb and starts/finishes a throwaway offline run. Returns
    ``True`` on success and ``False`` (after printing the error) otherwise.
    """
    try:
        import wandb

        # Try to initialize in offline mode to test setup
        wandb.init(mode="offline")
        wandb.finish()
    except Exception as e:  # any import or init failure means "not usable"
        print(f"✗ wandb setup issue: {e}")
        return False
    print("✓ wandb is properly installed")
    return True


def check_environment():
    """Check if we're in the right environment and directory.

    Returns ``True`` when both ``.venv`` and ``fast_td3/train.py`` exist
    relative to the current working directory, ``False`` otherwise.
    """
    required = (
        (".venv", "✗ Virtual environment not found. Run from the FastTD3 directory."),
        (
            "fast_td3/train.py",
            "✗ FastTD3 training script not found. Run from the FastTD3 directory.",
        ),
    )
    for path, complaint in required:
        if not os.path.exists(path):
            print(complaint)
            return False
    print("✓ Environment looks good")
    return True


def submit_job(script_path="run_fasttd3.slurm", use_wandb_online=False):
    """Submit the SLURM job.

    Args:
        script_path: Path of the sbatch script to submit.
        use_wandb_online: When true, prompt for a wandb API key and, if one
            is given, rewrite *script_path* in place to export the key and
            disable offline mode.

    Returns:
        ``True`` when ``sbatch`` exits with status 0, ``False`` otherwise
        (missing script, unreadable/unwritable script, sbatch failure).
    """
    if not os.path.exists(script_path):
        print(f"✗ SLURM script {script_path} not found")
        return False

    print(f"Submitting job with script: {script_path}")

    # If using online wandb, prompt for API key
    if use_wandb_online:
        api_key = _ask("Enter your wandb API key (or press Enter to skip): ").strip()
        if api_key:
            try:
                with open(script_path, "r", encoding="utf-8") as f:
                    content = f.read()
                content, key_set = enable_wandb_online(content, api_key)
                with open(script_path, "w", encoding="utf-8") as f:
                    f.write(content)
            except OSError as e:
                print(f"✗ Could not update {script_path}: {e}")
                return False

            if key_set:
                print("✓ Updated script with wandb API key")
                # The key is now stored in plain text inside the script.
                print(f"  Note: the key is saved in {script_path}; do not commit or share it.")
            else:
                print(f"✗ No WANDB_API_KEY line found in {script_path}; API key was not set")

    try:
        result = subprocess.run(
            ["sbatch", script_path], capture_output=True, text=True
        )
    except FileNotFoundError:
        print("✗ sbatch command not found. Are you on a SLURM cluster?")
        return False
    except Exception as e:
        print(f"✗ Error submitting job: {e}")
        return False

    if result.returncode != 0:
        print("✗ Job submission failed:")
        print(result.stderr.strip())
        return False

    print("✓ Job submitted successfully:")
    print(result.stdout.strip())

    # The submission already succeeded; an unparseable confirmation only
    # means the monitoring hints are skipped.
    job_id = parse_job_id(result.stdout)
    if job_id is not None:
        print("\nTo monitor the job:")
        print(f"  squeue -j {job_id}")
        print(f"  tail -f fasttd3_{job_id}.out")
    return True


def main():
    """Run the environment checks, ask for the wandb mode, and submit the job.

    Exits with status 1 when a check or the submission fails.
    """
    print("FastTD3 Job Submission Helper")
    print("=" * 30)

    # Check environment
    if not check_environment():
        sys.exit(1)

    if not check_wandb_setup():
        sys.exit(1)

    # Ask user about wandb mode
    answer = _ask("Use wandb online mode? (y/N): ")
    use_online = answer.strip().lower().startswith("y")

    # Submit job
    if not submit_job(use_wandb_online=use_online):
        print("\n❌ Job submission failed")
        sys.exit(1)

    print("\n🎉 Job submitted successfully!")
    print("\nTips:")
    print("- Check job status: squeue -u $USER")
    print("- View output: tail -f fasttd3_JOBID.out")
    print("- Cancel job: scancel JOBID")
    if not use_online:
        print(
            "- Job runs in wandb offline mode. "
            "Sync later with: wandb sync wandb/offline-run-*"
        )


if __name__ == "__main__":
    main()