reppo/submit_experiments.py
ys1087@partner.kit.edu 1caaa9d01f Add experiment infrastructure and production scripts
- Fix 6 critical bugs in original REPPO repository
- Add comprehensive README documentation
- Create production SLURM script for accelerated partition
- Add experiment submission script for batch jobs
- Algorithm now runs successfully with strong performance
- Ready for paper replication experiments on Brax suite
2025-07-22 18:47:43 +02:00

65 lines
2.1 KiB
Python
Executable File

#!/usr/bin/env python3
"""Submit REPPO experiments to replicate paper results"""
import subprocess
import os
def submit_job(env, env_name, seed, wandb_project="reppo_paper_replication"):
    """Submit a single REPPO run to SLURM via ``sbatch``.

    Creates a ``logs`` directory in the current working directory (if it does
    not exist yet) and submits ``slurm/run_reppo_prod.sh`` with a job name and
    stdout/stderr log files derived from ``env_name`` and ``seed``.

    Args:
        env: Experiment suite name supplied by the caller. Not used by this
            function at present.
        env_name: Environment name; used for the job name and log file names
            and exported to the job as ``ENV_NAME``.
        seed: Seed; used for the job name and log file names and exported to
            the job as ``SEED``.
        wandb_project: Not used by this function at present.

    Returns:
        The job ID printed by ``sbatch`` as a string (an empty string if
        ``sbatch`` succeeded but printed nothing), or ``None`` when ``sbatch``
        exited with a non-zero status.
    """
    import re  # local import: only needed for parsing the sbatch output

    # sbatch does not create the directory for --output/--error itself.
    os.makedirs('logs', exist_ok=True)

    # NOTE(review): per the sbatch documentation, listing variables in
    # --export without a leading "ALL" propagates only those variables (plus
    # SLURM_*) — confirm the job script does not rely on the submitting
    # shell's environment.
    command = [
        'sbatch',
        '--job-name', f'reppo_{env_name}_{seed}',
        '--output', f'logs/reppo_{env_name}_{seed}_%j.out',
        '--error', f'logs/reppo_{env_name}_{seed}_%j.err',
        '--export', f'ENV_NAME={env_name},SEED={seed}',
        'slurm/run_reppo_prod.sh',
    ]
    result = subprocess.run(command, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"{env_name} seed={seed}: {result.stderr}")
        return None

    # sbatch normally prints "Submitted batch job <id>", optionally followed
    # by " on cluster <name>"; taking the last token would then return the
    # cluster name, so match the ID explicitly first.
    match = re.search(r'Submitted batch job (\d+)', result.stdout)
    if match:
        job_id = match.group(1)
    else:
        # Unknown output format: fall back to the last token, and avoid an
        # IndexError when sbatch printed nothing at all.
        tokens = result.stdout.split()
        job_id = tokens[-1] if tokens else ""

    print(f"{env_name} seed={seed}: {job_id}")
    return job_id
def main():
    """Parse command-line options and submit one job per (environment, seed).

    Options:
        --experiment  Required; one of ``brax``, ``mjx`` or ``maniskill``.
                      Selects the list of environments to run.
        --seeds       Number of seeds per environment (default 5); seeds are
                      ``0 .. seeds-1``.
        --dry_run     Print the plan and return without submitting anything.

    After submitting, prints how many jobs were accepted and, if any
    submission failed, how many were rejected.
    """
    import argparse

    # Environments belonging to each experiment suite.
    suites = {
        'brax': ['ant', 'cheetah', 'humanoid', 'walker', 'hopper'],
        'mjx': ['CheetahRun', 'FingerSpin', 'HumanoidRun', 'WalkerRun'],  # DMC names
        'maniskill': ['PickSingleYCB-v1', 'PegInsertionSide-v1',
                      'UnitreeG1TransportBox-v1', 'RollBall-v1'],
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment', choices=list(suites), required=True)
    parser.add_argument('--seeds', type=int, default=5)
    parser.add_argument('--dry_run', action='store_true')
    args = parser.parse_args()

    envs = suites[args.experiment]

    print(f"Submitting {args.experiment} experiments")
    print(f"Environments: {envs}")
    print(f"Seeds: {args.seeds}")

    if args.dry_run:
        print("DRY RUN - not submitting")
        return

    job_count = 0
    failed_count = 0
    for env_name in envs:
        for seed in range(args.seeds):
            # submit_job returns None when sbatch rejected the submission;
            # only count jobs that were actually accepted.
            if submit_job(args.experiment, env_name, seed) is None:
                failed_count += 1
            else:
                job_count += 1

    print(f"Submitted {job_count} jobs")
    if failed_count:
        print(f"Failed to submit {failed_count} jobs")
    print("Monitor with: squeue -u $USER")
# Run the submission workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()