2023-07-06 18:06:20 +02:00
|
|
|
import wandb
|
|
|
|
import yaml
|
|
|
|
import os
|
2023-07-07 13:10:06 +02:00
|
|
|
import math
|
2023-07-06 18:06:20 +02:00
|
|
|
import random
|
|
|
|
import copy
|
2023-07-29 13:03:01 +02:00
|
|
|
import re
|
|
|
|
import itertools
|
|
|
|
from collections.abc import *
|
2023-07-06 18:06:20 +02:00
|
|
|
from functools import partial
|
2023-07-07 14:39:38 +02:00
|
|
|
from multiprocessing import Process
|
2023-07-09 16:12:38 +02:00
|
|
|
from threading import Thread
|
2023-07-12 11:07:33 +02:00
|
|
|
import git
|
|
|
|
import datetime
|
2023-07-12 13:06:14 +02:00
|
|
|
from pprint import pprint
|
2023-07-06 18:06:20 +02:00
|
|
|
|
|
|
|
import pdb
|
|
|
|
# Short alias for dropping into the debugger (PDB_Runner calls it).
d = pdb.set_trace

# When True, config keys left unconsumed after a run raise an Exception
# instead of only being printed.
REQUIRE_CONFIG_CONSUMED = False
# NOTE(review): not referenced anywhere else in this file.
DEFAULT_START_METHOD = 'fork'
# Fallback for wandb.init(reinit=...) when the wandb section sets no 'reinit'.
DEFAULT_REINIT = True

# Class used to spawn parallel agents; the trailing comment names the
# threading alternative.
Parallelization_Primitive = Process  # Thread

# pyslurm is optional; run_slurm refuses to run without it.
slurm_avaible = True
try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')

# TODO: Implement Testing
# TODO: Implement Ablative
|
|
|
|
|
|
|
|
|
|
|
|
class Slate():
    """Experiment launcher built around YAML experiment configs.

    An experiment config selects a runner through its ``runner`` key (looked
    up in ``self.runners``) and may contain ``scheduler``, ``slurm``,
    ``wandb``, ``sweep``, ``grid``, ``ablative`` and ``vars`` sections.
    Sections are *consumed* (removed from the dict) as they are read, so
    whatever is left after a run can be reported as unused.
    """

    def __init__(self, runners):
        """Create a launcher.

        Args:
            runners: mapping of runner name -> runner class, merged over the
                built-in ``void``, ``printConfig`` and ``pdb`` runners.
        """
        self.runners = {
            'void': Void_Runner,
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False  # git sha, resolved lazily by get_version()
        # False when SLURM_JOB_ID is unset; rendered as 'LOCAL' in templates.
        self.job_id = os.environ.get('SLURM_JOB_ID', False)
        self.task_id = None
        self.run_id = -1
        self._tmp_path = os.path.expandvars('$TMP')
        self.sweep_id = None
        self.verify = False  # ask for confirmation before submitting to slurm

    def load_config(self, filename, name):
        """Load experiment *name* from *filename*, resolving its imports.

        A deep copy of the merged config is kept in ``self._config`` and is
        available to string templates as ``{config}``. The ``vars`` section is
        removed from the returned config.
        """
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Return ``(config, stack)`` for the YAML document named *name*.

        The document's ``import`` entry (a comma-separated string or a list)
        is resolved recursively. Each entry is ``file:exp``, ``:exp`` (same
        file), ``file`` (meaning ``file:DEFAULT``) or ``$`` (meaning
        ``:DEFAULT``); paths are relative to *filename*'s directory. Keys of
        the importing document override keys of the imported one. *stack*
        collects every ``file:exp`` visited, in order.

        Raises:
            Exception: if an import entry is malformed or no document named
                *name* exists in *filename*.
        """
        # Fresh list per top-level call; a mutable default would accumulate
        # entries from every earlier load in the same process.
        if stack is None:
            stack = []
        stack.append(f'{filename}:{name}')
        with open(filename, 'r') as f:
            for doc in yaml.safe_load_all(f):
                # Skip empty (None) or non-mapping documents instead of crashing.
                if not isinstance(doc, Mapping) or 'name' not in doc or doc['name'] != name:
                    continue
                imports = doc.pop('import', None)
                if imports is None:
                    imports = []
                elif isinstance(imports, str):
                    imports = imports.split(',')
                for imp in imports:
                    imp = str(imp).strip()
                    if not imp:
                        continue  # tolerate empty entries, e.g. a trailing comma
                    if imp == "$":
                        imp = ':DEFAULT'
                    rel_path, *opt = imp.split(':')
                    if len(opt) == 0:
                        nested_name = 'DEFAULT'
                    elif len(opt) == 1:
                        nested_name = opt[0]
                    else:
                        raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                    nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if rel_path else filename
                    child, stack = self._load_config(nested_path, nested_name, stack=stack)
                    doc = self.deep_update(child, doc)
                return doc, stack
        raise Exception(f'Unable to find experiment <{name}> in <{filename}>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Recursively merge mapping *u* into *d* in place and return *d*.

        String keys of *u* containing dots are treated as paths
        (``'a.b': v`` sets ``d['a']['b']``) while *traverse_dot_notation* is
        True. Below any ``parameters`` key (presumably wandb sweep parameter
        definitions) keys are taken literally. Values from *u* win. A mapping
        value replaces any non-mapping value it lands on.
        """
        for kstr, v in u.items():
            if traverse_dot_notation and isinstance(kstr, str):
                ks = kstr.split('.')
            else:
                ks = [kstr]
            # Decided per key so the literal-key mode does not leak into
            # sibling keys processed later in this loop.
            child_traverse = traverse_dot_notation and 'parameters' not in ks
            head = d
            for k in ks:
                parent = head
                if k not in head:
                    head[k] = {}
                head = head[k]
            if isinstance(v, Mapping):
                # Merge into the node the path resolved to, not into a
                # same-named key at the top level of d.
                if not isinstance(head, MutableMapping):
                    head = {}
                parent[ks[-1]] = self.deep_update(head, v, traverse_dot_notation=child_traverse)
            else:
                parent[ks[-1]] = v
        return d

    def expand_vars(self, string, delta_desc='BASE', **kwargs):
        """Fill ``{...}`` placeholders in *string*; non-strings pass through.

        Available names: ``delta_desc``, ``rand`` (random int), ``tmp``
        ($TMP), ``job_id`` ('LOCAL' outside slurm), ``task_id`` (0 if
        unset), ``run_id`` and any *kwargs*. The exact string ``'{rand}'``
        returns the random int itself rather than a string.
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(delta_desc=delta_desc, **kwargs, rand=rand, tmp=self._tmp_path, job_id=(self.job_id or 'LOCAL'), task_id=(self.task_id or 0), run_id=self.run_id)
        return string

    def apply_nested(self, d, f):
        """Replace every leaf of nested dict/list structure *d* by ``f(leaf)``, in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap the element so apply_nested can replace it in place.
                    ptr = {'PTR': e}
                    self.apply_nested(ptr, f)
                    d[k][i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Run :meth:`expand_vars` (with *kwargs*) on every leaf of *dict*, in place."""
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Remove ``conf[key]`` and return it with placeholders expanded.

        *key* may be a dotted path (``'a.b.c'``); only the leaf is removed.
        An empty *key* processes *conf* itself without removing anything.
        If *default* is not None it is returned when the key is missing (or
        when *conf* is not a mapping); otherwise a missing key raises
        KeyError. Strings are expanded repeatedly until no ``{`` remains;
        with ``expand=True`` containers are expanded recursively (dicts in
        place). Extra *kwargs* are forwarded to :meth:`expand_vars`.
        """
        if key == '':
            return self._expand_value(conf, expand, **kwargs)
        keys_arr = key.split('.')
        if len(keys_arr) > 1:
            child = conf.get(keys_arr[0], {})
            child_keys = '.'.join(keys_arr[1:])
            return self.consume(child, child_keys, default=default, expand=expand, **kwargs)
        k = keys_arr[0]
        if default is not None:
            if not isinstance(conf, Mapping):
                return default
            val = conf.get(k, default)
        else:
            val = conf[k]
        if k in conf:
            del conf[k]
        return self._expand_value(val, expand, **kwargs)

    def _expand_value(self, val, expand, **kwargs):
        """Expand placeholders in *val* on behalf of :meth:`consume`."""
        if expand and not isinstance(val, str):
            if isinstance(val, Mapping):
                self.deep_expand_vars(val, config=self._config, **kwargs)
                return val
            # Lists and scalars: wrap so apply_nested can handle them too.
            wrapper = {'PTR': val}
            self.deep_expand_vars(wrapper, config=self._config, **kwargs)
            return wrapper['PTR']
        # Expand repeatedly so placeholders produced by an expansion get
        # resolved too; stop once the value is no longer a string ('{rand}'
        # expands to an int, which previously crashed on `.find`).
        while isinstance(val, str) and '{' in val:
            val = self.expand_vars(val, config=self._config, **kwargs)
        return val

    def get_version(self):
        """Return the commit sha of the enclosing git repository (cached)."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            self._version = repo.head.object.hexsha
        return self._version

    def _calc_num_jobs(self, schedC, num_conv_versions):
        """Return how many jobs are needed for all repetitions.

        Repetitions default to ``reps_per_version * num_conv_versions``;
        each job handles ``agents_per_job * reps_per_agent`` of them.
        *schedC* is not modified.
        """
        schedulerC = copy.deepcopy(schedC)
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        return math.ceil(reps / reps_per_job)

    def _reps_for_job(self, schedC, task_id, num_conv_versions):
        """Return the repetition ids job *task_id* works on.

        Repetitions are dealt round-robin across jobs. A *task_id* of None
        means "all repetitions". *schedC* is not modified.
        """
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
        if task_id is None:
            return list(range(0, reps))
        reps_for_job = [list(range(job, reps, num_jobs)) for job in range(num_jobs)]
        return reps_for_job[task_id]

    def run_local(self, filename, name, task_id, sweep_id):
        """Run the repetitions assigned to *task_id* in this process.

        An existing *sweep_id* is joined; otherwise a sweep is created when
        the config enables one.
        """
        self.task_id = task_id
        config = self.load_config(filename, name)
        num_conv_versions = self._get_num_conv_versions(config)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, task_id, num_conv_versions)
        self.sweep_id = sweep_id
        self._init_sweep(config)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit the experiment as a slurm job array running ``main.py``.

        Remaining keys of the ``slurm`` section are forwarded to
        ``pyslurm.JobSubmitDescription``. The submission is appended to
        ``job_hist.log``.
        """
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler')
        s_name = self.consume(slurmC, 'name')

        num_conv_versions = self._get_num_conv_versions(config)

        # Pre Validation
        runnerName = self.consume(config, 'runner')
        # Expand a throwaway copy of the wandb section so broken placeholders
        # fail now rather than inside the submitted jobs.
        self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            runner = Runner(self, config)
            runner.setup()

        self._init_sweep(config)
        self.consume(config, 'wandb')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        run_cmd = f'python3 {python_script} {filename} {name} -t $SLURM_ARRAY_TASK_ID'
        # Only forward a real sweep id; previously the literal string 'None'
        # was passed when no sweep was configured.
        if self.sweep_id is not None:
            run_cmd += f' --sweep_id {self.sweep_id}'
        sh_lines += [run_cmd]
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC, num_conv_versions)
        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        if self.verify:
            input(f'<Press enter to submit {num_jobs} to slurm>')
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            # Array indices are 0..last_job_idx, matching the submitted array.
            f.write(f'{filename}:{name} submitted to slurm with ids {job_id}_0-{last_job_idx} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Split *rep_ids* over up to ``agents_per_job`` parallel agents and run them.

        With a single agent the work runs in the calling process; otherwise
        ``Parallelization_Primitive`` instances are started and joined.
        """
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)
        # 'scheduler' has just been removed from config, so the core-binding
        # flag must be read here and handed down explicitly.
        bind_to_core = self.consume(schedC, 'bind_agent_to_core', False)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0, bind_to_core=bind_to_core)
            return

        procs = []
        reps_done = 0
        for p in range(num_p):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = list(rep_ids[reps_done:reps_done + num_reps])
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p, bind_to_core=bind_to_core))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print(f'[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep if none is set yet and ``sweep.enable`` is true.

        *config* is left untouched, so later stages still see the ``sweep``
        section. Before this change the section was consumed here, and local
        runs then silently skipped the sweep agent.
        """
        if self.sweep_id is not None:
            return
        cfg = copy.deepcopy(config)
        if not self.consume(cfg, 'sweep.enable', False):
            return
        sweepC = self.consume(cfg, 'sweep')
        project = self.consume(cfg['wandb'], 'project')
        self.sweep_id = wandb.sweep(
            sweep=sweepC,
            project=project
        )

    def _run_process(self, orig_config, rep_ids, p_ind, bind_to_core=False):
        """Run *rep_ids* as a sweep agent (if enabled) or as plain runs."""
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind, bind_to_core=bind_to_core), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind, bind_to_core=bind_to_core)

    @staticmethod
    def _bind_to_core(p_ind):
        """Pin the current process to one CPU chosen by agent index *p_ind*."""
        os.sched_setaffinity(0, [p_ind % os.cpu_count()])

    def _check_consumed(self, config):
        """Report (or raise, if REQUIRE_CONFIG_CONSUMED) keys nobody consumed."""
        if config != {}:
            msg = ('Config was not completely consumed: ', config)
            if REQUIRE_CONFIG_CONSUMED:
                raise Exception(msg)
            print(msg)

    def _run_single(self, orig_config, rep_ids, p_ind, bind_to_core=False):
        """Execute one wandb run per repetition id in *rep_ids*.

        *orig_config* is consumed (``runner``, ``wandb.project``); each
        repetition gets its own deep copy expanded into a grid/ablative
        version.
        """
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName = self.consume(orig_config, 'runner')
        project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))

        Runner = self.runners[runnerName]

        if self.consume(orig_config, 'scheduler.bind_agent_to_core', False) or bind_to_core:
            self._bind_to_core(p_ind)

        for r in rep_ids:
            self.run_id = r
            config = copy.deepcopy(orig_config)
            runnerConf = self._make_config_for_run(config, r)
            wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
            # reinit/settings are consumed before **wandbC is unpacked, so
            # they are not passed twice.
            with wandb.init(
                project=project,
                config=copy.deepcopy(runnerConf),
                reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
                settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
                **wandbC
            ) as run:
                runner = Runner(self, runnerConf)
                runner.setup()
                runner.run(run)

            self._check_consumed(runnerConf)

    def _run_from_sweep(self, orig_config, p_ind, bind_to_core=False):
        """Execute one sweep run, overlaying the sweep-chosen ``wandb.config``."""
        # wandb.agent calls this repeatedly with the same dict (bound via
        # partial); work on a copy so later calls still find 'runner'/'wandb'.
        orig_config = copy.deepcopy(orig_config)
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')

        Runner = self.runners[runnerName]

        if self.consume(orig_config, 'scheduler.bind_agent_to_core', False) or bind_to_core:
            self._bind_to_core(p_ind)

        with wandb.init(
            project=project,
            reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
            settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
            **wandbC
        ) as run:
            config = copy.deepcopy(orig_config)
            self.deep_update(config, wandb.config)
            run.config = copy.deepcopy(config)
            runner = Runner(self, config)
            runner.setup()
            runner.run(run)

        self._check_consumed(config)

    def _make_configs_for_runs(self, config):
        """Return every config version: grid combinations plus ablative variants."""
        c = copy.deepcopy(config)
        grid_versions = self._make_grid_versions(c)
        return self._make_ablative_versions(c, grid_versions)

    def _get_num_conv_versions(self, config):
        """Return how many distinct config versions *config* expands into."""
        return len(self._make_configs_for_runs(config))

    def _make_config_for_run(self, config, r):
        """Return the config version used by repetition *r* (cycling through versions)."""
        all_versions = self._make_configs_for_runs(config)
        i = r % len(all_versions)
        print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
        cur_conf = all_versions[i]
        cur_conf.pop('ablative', None)
        return cur_conf

    def _make_grid_versions(self, config):
        """Expand the ``grid`` section into the cartesian product of its lists."""
        if 'grid' in config:
            return params_combine(config, 'grid', itertools.product)
        return [config]

    def _make_ablative_versions(self, config, grid_versions):
        """Append one-parameter-changed variants when an ``ablative`` section exists."""
        if 'ablative' in config:
            return grid_versions + ablative_expand(grid_versions)
        return grid_versions

    def from_args(self):
        """Parse command-line arguments and run locally or submit to slurm."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-t", "--task_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)
        parser.add_argument("--ask_verify", action="store_true")

        args = parser.parse_args()

        print(f'[i] I have task_id {args.task_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        # A CLI usage error (exit status 2) rather than an assert, which
        # would vanish under `python -O`.
        if args.config_file is None:
            parser.error('Need to supply config file.')

        if args.slurm:
            if args.ask_verify:
                self.verify = True
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.task_id, args.sweep_id)
|
2023-07-06 18:06:20 +02:00
|
|
|
|
|
|
|
|
2023-07-29 13:03:01 +02:00
|
|
|
def params_combine(config: dict, key: str, iter_func):
    """Expand ``config[key]`` (a nested dict of value lists) into many configs.

    ``iter_func`` (e.g. ``itertools.product``) receives every value list and
    yields one tuple of values per combination. Each resulting config is a
    deep copy of *config* with *key* removed, the chosen values inserted at
    their nested paths, and ``delta_desc`` extended to describe them. With
    ``iter_func=None`` the config is returned unchanged in a one-item list.
    """
    if iter_func is None:
        return [config]

    # Flatten the grid into {('a', 'b'): [values...]}: the key is the path of
    # nested keys, the value is the list of candidate values.
    tuple_dict = flatten_dict_to_tuple_keys(config[key])
    # map(str, ...) so non-string YAML keys (e.g. ints) do not break naming.
    param_names = ['.'.join(map(str, t)) for t in tuple_dict]

    combined_configs = []
    # One new config per parameter combination.
    for values in iter_func(*tuple_dict.values()):
        _config = copy.deepcopy(config)
        # The grid section itself is not part of the expanded config.
        del _config[key]
        for path, value in zip(tuple_dict, values):
            insert_deep_dictionary(d=_config, t=path, value=value)
        _config = extend_config_name(_config, param_names, values)
        combined_configs.append(_config)
    return combined_configs
|
|
|
|
|
|
|
|
|
|
|
|
def ablative_expand(conf_list):
    """Build one-parameter-changed variants of every config in *conf_list*.

    For each config, every value listed under its ``ablative`` section
    yields a deep copy with only that single value inserted at its nested
    path, and ``delta_desc`` extended accordingly. Variants are returned in
    input-config order, then parameter order, then value order.
    """
    variants = []
    for base in conf_list:
        ablations = flatten_dict_to_tuple_keys(base['ablative'])
        names = ['.'.join(path) for path in ablations]
        for (path, candidates), label in zip(ablations.items(), names):
            for candidate in candidates:
                variant = copy.deepcopy(base)
                insert_deep_dictionary(variant, path, candidate)
                variants.append(extend_config_name(variant, [label], [candidate]))
    return variants
|
|
|
|
|
|
|
|
|
|
|
|
def flatten_dict_to_tuple_keys(d: MutableMapping):
    """Flatten nested mappings into ``{key_path_tuple: list}``.

    Only mutable-sequence leaves (lists) are kept; any other leaf value is
    dropped. Nested mutable mappings contribute their leaves with the parent
    key prepended to the path.
    """
    result = {}
    for key, value in d.items():
        if isinstance(value, MutableMapping):
            for sub_path, leaf in flatten_dict_to_tuple_keys(value).items():
                result[(key,) + sub_path] = leaf
        elif isinstance(value, MutableSequence):
            result[(key,)] = value
    return result
|
|
|
|
|
|
|
|
|
|
|
|
def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Set *value* at the nested key path *t* inside *d*, in place.

    Missing intermediate levels are created as empty dicts. A non-tuple *t*
    is used directly as a single key.
    """
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for step in t[:-1]:
        if step not in node:
            node[step] = dict()
        node = node[step]
    node[t[-1]] = value
|
|
|
|
|
|
|
|
|
|
|
|
def append_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Append *value* to the list at nested key path *t* inside *d*, in place.

    Missing intermediate levels become empty dicts and a missing leaf becomes
    an empty list first. A non-tuple *t* is used as a single key and *value*
    is assigned, not appended.
    """
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for step in t[:-1]:
        if step not in node:
            node[step] = dict()
        node = node[step]
    leaf = t[-1]
    if leaf not in node:
        node[leaf] = []
    node[leaf].append(value)
|
|
|
|
|
|
|
|
|
|
|
|
def extend_config_name(config: dict, param_names: list, values: list) -> dict:
    """Append a short ``name+value`` description to ``config['delta_desc']``.

    The description comes from ``convert_param_names``; an existing
    ``delta_desc`` is joined to it with ``'_'``. *config* is modified in
    place and also returned.
    """
    suffix = convert_param_names(param_names, values)
    if 'delta_desc' in config:
        config['delta_desc'] = config['delta_desc'] + '_' + suffix
    else:
        config['delta_desc'] = suffix
    return config
|
|
|
|
|
|
|
|
|
|
|
|
def convert_param_names(_param_names: list, values: list) -> str:
    """Render parameter/value pairs as a compact, filename-friendly string.

    Each pair becomes ``shorten_param(name) + str(value)``, and pairs are
    joined with ``'_'``. Quotes, spaces and closing brackets are then removed,
    and opening brackets and commas become ``'_'``.
    """
    joined = '_'.join(f'{shorten_param(k)}{v}' for k, v in zip(_param_names, values))
    # Single translate pass, equivalent to the previous chain of regex
    # substitutions. Those used non-raw patterns like "[(\[]", which are
    # invalid escape sequences (SyntaxWarning on Python 3.12+).
    table = str.maketrans({
        "'": None, ' ': None, '"': None, ')': None, ']': None,
        '(': '_', '[': '_', ',': '_',
    })
    return joined.translate(table)
|
|
|
|
|
|
|
|
|
|
|
|
def shorten_param(_param_name):
    """Abbreviate a dotted parameter path for use in run descriptions.

    Parent components keep their first three characters. The leaf becomes
    the initials of its ``_``-separated words, so ``'model.learning_rate'``
    becomes ``'mod.lr'``. Empty words (leading, trailing or doubled
    underscores) are skipped instead of raising IndexError.
    """
    *parents, leaf = _param_name.split('.')
    prefix = '.'.join(part[:3] for part in parents)
    leaf_abbrev = ''.join(word[0] for word in leaf.split('_') if word)
    if prefix:
        return prefix + '.' + leaf_abbrev
    return leaf_abbrev
|
|
|
|
|
|
|
|
|
2023-07-12 13:06:14 +02:00
|
|
|
class Slate_Runner():
    """Base class for runners executed by Slate.

    Slate constructs a runner with itself and the run's config, calls
    ``setup()`` and then ``run(run)``. Both hooks do nothing here.
    """

    def __init__(self, slate, config):
        self.slate, self.config = slate, config

    def setup(self):
        """Prepare for the run; no-op by default."""

    def run(self, run):
        """Execute the run; *run* is the object Slate obtains from wandb.init. No-op by default."""
|
|
|
|
|
|
|
|
|
|
|
|
class Print_Config_Runner(Slate_Runner):
    """Runner that prints its config, raw and then expanded, and consumes all of it."""

    def run(self, run):
        cfg = self.config
        pprint(cfg)
        print('---')
        pprint(self.slate.consume(cfg, '', expand=True))
        # Empty the dict in place so Slate finds nothing left unconsumed.
        cfg.clear()
|
2023-07-06 18:20:37 +02:00
|
|
|
|
|
|
|
|
2023-07-29 14:28:23 +02:00
|
|
|
class Void_Runner(Slate_Runner):
    """Runner that does nothing except mark its whole config as consumed."""

    def run(self, run):
        # Empty the dict in place so Slate finds nothing left unconsumed.
        self.config.clear()
|
|
|
|
|
|
|
|
|
2023-07-12 13:06:14 +02:00
|
|
|
class PDB_Runner(Slate_Runner):
    """Runner that stops in the debugger; ``self.slate``, ``self.config`` and *run* are in scope."""

    def run(self, run):
        pdb.set_trace()
|
2023-07-12 12:23:18 +02:00
|
|
|
|
|
|
|
|
2023-07-06 18:06:20 +02:00
|
|
|
# This module only provides the Slate library. Presumably it is driven from a
# separate entry script (run_slurm generates `python3 main.py ...` command
# lines that match Slate.from_args), so running it directly is an error.
if __name__ == '__main__':
    raise Exception('You are using it wrong...')
|