reppo/submit_dmc_experiments.py
ys1087@partner.kit.edu a02e258f1c seperate dmc setup...
2025-07-29 14:58:43 +02:00

98 lines
2.6 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Submit DMC (mujoco_playground) experiments for REPPO
"""
import argparse
import os
import subprocess
import time
# The 24 DMC tasks from the experiment plan. Each name is passed verbatim to
# the batch script as ENV_NAME (see submit_job).
# NOTE(review): some of these names (e.g. FishUpright, ManipulatorBringBall,
# SwimmerSwimmer15, PointMassEasy) may not be registered in mujoco_playground's
# DM Control Suite -- confirm against its environment registry before a full run.
DMC_TASKS = [
    "AcrobotSwingup",
    "CartpoleBalance",
    "CartpoleSwingup",
    "CheetahRun",
    "FingerSpin",
    "FingerTurnEasy",
    "FingerTurnHard",
    "FishUpright",
    "FishSwim",
    "HopperStand",
    "HopperHop",
    "HumanoidStand",
    "HumanoidWalk",
    "HumanoidRun",
    "ManipulatorBringBall",
    "PendulumSwingup",
    "PointMassEasy",
    "ReacherEasy",
    "ReacherHard",
    "SwimmerSwimmer6",
    "SwimmerSwimmer15",
    "WalkerStand",
    "WalkerWalk",
    "WalkerRun",
]
def submit_job(env_name, seed=0, script="slurm/run_reppo_dmc_prod.sh"):
    """Submit one REPPO DMC run to Slurm via ``sbatch``.

    The task and seed reach the batch script through the ``ENV_NAME`` and
    ``SEED`` environment variables, layered on top of this process's
    environment.

    Args:
        env_name: Name of the DMC task, e.g. ``"CheetahRun"``.
        seed: Random seed for the run; also embedded in the Slurm job name.
        script: Path of the batch script handed to ``sbatch``.

    Returns:
        The job ID, taken as the last whitespace-separated token of sbatch's
        stdout. Returns ``None`` if sbatch exits non-zero or prints nothing
        on stdout.

    Raises:
        FileNotFoundError: if ``sbatch`` cannot be executed. This is not
            caught on purpose, so a missing Slurm install fails fast instead
            of repeating the same error for every job.
    """
    cmd = [
        "sbatch",
        f"--job-name=reppo_dmc_{env_name}_seed{seed}",
        script,
    ]
    # NOTE(review): this relies on sbatch propagating the caller's environment
    # into the job (Slurm's default --export behaviour); confirm the batch
    # script does not override it.
    env = {**os.environ, "ENV_NAME": env_name, "SEED": str(seed)}

    print(f"Submitting {env_name} (seed {seed})...")
    try:
        result = subprocess.run(
            cmd, env=env, capture_output=True, text=True, check=True
        )
    except subprocess.CalledProcessError as e:
        print(f" -> Error: {e}")
        print(f" -> Stdout: {e.stdout}")
        print(f" -> Stderr: {e.stderr}")
        return None

    # sbatch normally prints "Submitted batch job <id>"; keep the last token.
    # Guard against empty output instead of crashing with IndexError.
    tokens = result.stdout.split()
    if not tokens:
        print(" -> Error: sbatch exited successfully but printed no job ID")
        print(f" -> Stderr: {result.stderr}")
        return None
    job_id = tokens[-1]
    print(f" -> Job ID: {job_id}")
    return job_id
def main(argv=None):
    """Parse command-line options and submit one Slurm job per (task, seed).

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. The default,
            ``None``, keeps normal command-line behaviour.

    Returns:
        Exit status suitable for ``sys.exit``: 0 when every submission
        succeeded, 1 when at least one failed.
    """
    parser = argparse.ArgumentParser(description="Submit DMC experiments")
    parser.add_argument("--seeds", type=int, default=5,
                        help="Number of seeds to run")
    parser.add_argument("--tasks", nargs="+", default=DMC_TASKS,
                        help="List of tasks to run")
    parser.add_argument("--delay", type=float, default=1.0,
                        help="Delay between submissions (seconds)")
    args = parser.parse_args(argv)

    # Validate up front. A negative delay would otherwise make time.sleep
    # raise ValueError only after the first job had already been submitted.
    if args.seeds < 0:
        parser.error("--seeds must be non-negative")
    if args.delay < 0:
        parser.error("--delay must be non-negative")

    runs = [(task, seed) for task in args.tasks for seed in range(args.seeds)]

    print(f"Submitting {len(args.tasks)} DMC tasks with {args.seeds} seeds each")
    print(f"Total jobs: {len(runs)}")
    print()

    job_ids = []
    failed = []
    for index, (task, seed) in enumerate(runs):
        if index:
            # Space submissions out to avoid overwhelming the scheduler.
            # Sleeping before every job except the first means no wasted
            # wait after the final submission.
            time.sleep(args.delay)
        job_id = submit_job(task, seed)
        if job_id:
            job_ids.append(job_id)
        else:
            failed.append((task, seed))

    print(f"\nSubmitted {len(job_ids)} jobs successfully:")
    for number, job_id in enumerate(job_ids, start=1):
        print(f" {number}: {job_id}")
    if failed:
        print(f"\nFailed to submit {len(failed)} jobs:")
        for task, seed in failed:
            print(f" {task} (seed {seed})")

    print("\nMonitor with: squeue -u $USER")
    print("Check logs in: logs/")
    return 1 if failed else 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (None maps to status 0).
    raise SystemExit(main())