Maybe it's a package now
This commit is contained in:
parent
def7c55f8e
commit
ea74d6a712
73
example.py
Normal file
73
example.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from slate import Slate
|
||||||
|
|
||||||
|
import fancy_gym
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||||
|
from wandb.integration.sb3 import WandbCallback
|
||||||
|
import gymnasium as gym
|
||||||
|
import copy
|
||||||
|
|
||||||
|
PCA = None
|
||||||
|
|
||||||
|
|
||||||
|
def debug_runner(slate, run, config):
|
||||||
|
print(config)
|
||||||
|
for k in list(config.keys()):
|
||||||
|
del config[k]
|
||||||
|
import time
|
||||||
|
time.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
|
def sb3_runner(slate, run, config):
|
||||||
|
videoC, testC, envC, algoC, pcaC = slate.consume(config, 'video', {}), slate.consume(config, 'test', {}), slate.consume(config,
|
||||||
|
'env', {}), slate.consume(config, 'algo', {}), slate.consume(config, 'pca', {})
|
||||||
|
assert config == {}
|
||||||
|
|
||||||
|
env = DummyVecEnv([make_env_func(slate, envC)])
|
||||||
|
if slate.consume(videoC, 'enable', False):
|
||||||
|
env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
|
||||||
|
|
||||||
|
assert slate.consume(algoC, 'name') == 'PPO'
|
||||||
|
policy_name = slate.consume(algoC, 'policy_name')
|
||||||
|
|
||||||
|
total_timesteps = config.get('run', {}).get('total_timesteps', {})
|
||||||
|
|
||||||
|
model = PPO(policy_name, env, **algoC)
|
||||||
|
|
||||||
|
if slate.consume(pcaC, 'enable', False):
|
||||||
|
model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
|
||||||
|
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=total_timesteps,
|
||||||
|
callback=WandbCallback()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_env_func(slate, env_conf):
|
||||||
|
conf = copy.deepcopy(env_conf)
|
||||||
|
name = slate.consume(conf, 'name')
|
||||||
|
legacy_fancy = slate.consume(conf, 'legacy_fancy', False)
|
||||||
|
wrappers = slate.consume(conf, 'wrappers', [])
|
||||||
|
|
||||||
|
def func():
|
||||||
|
if legacy_fancy: # TODO: Remove when no longer needed
|
||||||
|
fancy_gym.make(name, **conf)
|
||||||
|
else:
|
||||||
|
env = gym.make(name, **conf)
|
||||||
|
|
||||||
|
# TODO: Implement wrappers
|
||||||
|
|
||||||
|
env = Monitor(env)
|
||||||
|
return env
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
runners = {
|
||||||
|
'sb3': sb3_runner,
|
||||||
|
'debug': debug_runner
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
slate = Slate(runners)
|
||||||
|
slate.from_args()
|
271
main.py
271
main.py
@ -1,271 +0,0 @@
|
|||||||
#import fancy_gym
|
|
||||||
#from stable_baselines3 import PPO
|
|
||||||
#from stable_baselines3.common.monitor import Monitor
|
|
||||||
#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
|
||||||
import wandb
|
|
||||||
from wandb.integration.sb3 import WandbCallback
|
|
||||||
#import gymnasium as gym
|
|
||||||
import yaml
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import copy
|
|
||||||
import collections.abc
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import pdb
|
|
||||||
d = pdb.set_trace
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pyslurm
|
|
||||||
except ImportError:
|
|
||||||
slurm_avaible = False
|
|
||||||
else:
|
|
||||||
slurm_avaible = True
|
|
||||||
|
|
||||||
PCA = None
|
|
||||||
|
|
||||||
# TODO: Implement Slurm
|
|
||||||
# TODO: Implement Parallel
|
|
||||||
# TODO: Implement Testing
|
|
||||||
# TODO: Implement Ablative
|
|
||||||
# TODO: Implement PCA
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(filename, name):
|
|
||||||
config, stack = _load_config(filename, name)
|
|
||||||
print('[i] Merged Configs: ', stack)
|
|
||||||
deep_expand_vars(config, config=config)
|
|
||||||
consume(config, 'vars', {})
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _load_config(filename, name, stack=[]):
|
|
||||||
stack.append(f'{filename}:{name}')
|
|
||||||
with open(filename, 'r') as f:
|
|
||||||
docs = yaml.safe_load_all(f)
|
|
||||||
for doc in docs:
|
|
||||||
if 'name' in doc:
|
|
||||||
if doc['name'] == name:
|
|
||||||
if 'import' in doc:
|
|
||||||
imports = doc['import'].split(',')
|
|
||||||
del doc['import']
|
|
||||||
for imp in imports:
|
|
||||||
if imp[0] == ' ':
|
|
||||||
imp = imp[1:]
|
|
||||||
if imp == "$":
|
|
||||||
imp = ':DEFAULT'
|
|
||||||
rel_path, *opt = imp.split(':')
|
|
||||||
if len(opt) == 0:
|
|
||||||
nested_name = 'DEFAULT'
|
|
||||||
elif len(opt) == 1:
|
|
||||||
nested_name = opt[0]
|
|
||||||
else:
|
|
||||||
raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
|
|
||||||
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
|
|
||||||
child, stack = _load_config(nested_path, nested_name, stack=stack)
|
|
||||||
doc = deep_update(child, doc)
|
|
||||||
return doc, stack
|
|
||||||
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
|
||||||
|
|
||||||
|
|
||||||
def deep_update(d, u):
|
|
||||||
for kstr, v in u.items():
|
|
||||||
ks = kstr.split('.')
|
|
||||||
head = d
|
|
||||||
for k in ks:
|
|
||||||
last_head = head
|
|
||||||
if k not in head:
|
|
||||||
head[k] = {}
|
|
||||||
head = head[k]
|
|
||||||
if isinstance(v, collections.abc.Mapping):
|
|
||||||
last_head[ks[-1]] = deep_update(d.get(k, {}), v)
|
|
||||||
else:
|
|
||||||
last_head[ks[-1]] = v
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def expand_vars(string, **kwargs):
|
|
||||||
if isinstance(string, str):
|
|
||||||
return string.format(**kwargs)
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def apply_nested(d, f):
|
|
||||||
for k, v in d.items():
|
|
||||||
if isinstance(v, dict):
|
|
||||||
apply_nested(v, f)
|
|
||||||
elif isinstance(v, list):
|
|
||||||
for i, e in enumerate(v):
|
|
||||||
apply_nested({'PTR': d[k][i]}, f)
|
|
||||||
else:
|
|
||||||
d[k] = f(v)
|
|
||||||
|
|
||||||
|
|
||||||
def deep_expand_vars(dict, **kwargs):
|
|
||||||
apply_nested(dict, lambda x: expand_vars(x, **kwargs))
|
|
||||||
|
|
||||||
|
|
||||||
def consume(conf, key, default=None):
|
|
||||||
keys_arr = key.split('.')
|
|
||||||
if len(keys_arr) == 1:
|
|
||||||
k = keys_arr[0]
|
|
||||||
if default != None:
|
|
||||||
val = conf.get(k, default)
|
|
||||||
else:
|
|
||||||
val = conf[k]
|
|
||||||
if k in conf:
|
|
||||||
del conf[k]
|
|
||||||
return val
|
|
||||||
child = conf.get(keys_arr[0], {})
|
|
||||||
child_keys = '.'.join(keys_arr[1:])
|
|
||||||
return consume(child, child_keys, default=default)
|
|
||||||
|
|
||||||
|
|
||||||
def run_local(filename, name, job_num=None):
|
|
||||||
config = load_config(filename, name)
|
|
||||||
if consume(config, 'sweep.enable', False):
|
|
||||||
sweepC = consume(config, 'sweep')
|
|
||||||
project = consume(config, 'wandb.project')
|
|
||||||
sweep_id = wandb.sweep(
|
|
||||||
sweep=sweepC,
|
|
||||||
project=project
|
|
||||||
)
|
|
||||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
|
||||||
wandb.agent(sweep_id, function=partial(run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
|
|
||||||
else:
|
|
||||||
consume(config, 'sweep', {})
|
|
||||||
run_single(config)
|
|
||||||
|
|
||||||
|
|
||||||
def run_from_sweep(orig_config, runnerName, project, wandbC):
|
|
||||||
runner = Runners[runnerName]
|
|
||||||
|
|
||||||
with wandb.init(
|
|
||||||
project=project,
|
|
||||||
**wandbC
|
|
||||||
) as run:
|
|
||||||
config = copy.deepcopy(orig_config)
|
|
||||||
deep_update(config, wandb.config)
|
|
||||||
runner(run, config)
|
|
||||||
|
|
||||||
assert config == {}, ('Config was not completely consumed: ', config)
|
|
||||||
|
|
||||||
|
|
||||||
def run_slurm(filename, name):
|
|
||||||
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
|
||||||
config = load_config(filename, name)
|
|
||||||
slurmC = consume(config, 'slurm')
|
|
||||||
s_name = consume(slurmC, 'name')
|
|
||||||
|
|
||||||
python_script = 'main.py'
|
|
||||||
sh_lines = consume(slurmC, 'sh_lines', [])
|
|
||||||
if venv := consume(slurmC, 'venv', False):
|
|
||||||
sh_lines += [f'source activate {venv}']
|
|
||||||
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
|
||||||
script = " && ".join(sh_lines)
|
|
||||||
|
|
||||||
num_jobs = 1
|
|
||||||
|
|
||||||
last_job_idx = num_jobs - 1
|
|
||||||
num_parallel_jobs = min(consume(config, 'slurm.num_parallel_jobs', num_jobs), num_jobs)
|
|
||||||
array = f'0-{last_job_idx}%{num_parallel_jobs}'
|
|
||||||
job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **config['slurm'])
|
|
||||||
job_id = job.submit()
|
|
||||||
print(f'[i] Job submitted to slurm with id {job_id}')
|
|
||||||
|
|
||||||
|
|
||||||
def run_single(config):
|
|
||||||
runnerName, wandbC = consume(config, 'runner'), consume(config, 'wandb', {})
|
|
||||||
|
|
||||||
runner = Runners[runnerName]
|
|
||||||
|
|
||||||
with wandb.init(
|
|
||||||
project=consume(wandbC, 'project'),
|
|
||||||
config=config,
|
|
||||||
**wandbC
|
|
||||||
) as run:
|
|
||||||
runner(run, config)
|
|
||||||
|
|
||||||
assert config == {}, ('Config was not completely consumed: ', config)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("config_file", nargs='?', default=None)
|
|
||||||
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
|
||||||
parser.add_argument("-s", "--slurm", action="store_true")
|
|
||||||
parser.add_argument("-w", "--worker", action="store_true")
|
|
||||||
parser.add_argument("-j", "--job_num", default=None)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.worker:
|
|
||||||
raise Exception('Not yet implemented')
|
|
||||||
|
|
||||||
assert args.config_file != None, 'Need to supply config file.'
|
|
||||||
if args.slurm:
|
|
||||||
run_slurm(args.config_file, args.experiment)
|
|
||||||
else:
|
|
||||||
run_local(args.config_file, args.experiment, args.job_num)
|
|
||||||
|
|
||||||
|
|
||||||
def debug_runner(run, config):
|
|
||||||
print(config)
|
|
||||||
for k in list(config.keys()):
|
|
||||||
del config[k]
|
|
||||||
import time
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
|
|
||||||
def sb3_runner(run, config):
|
|
||||||
videoC, testC, envC, algoC, pcaC = consume(config, 'video', {}), consume(config, 'test', {}), consume(config, 'env', {}), consume(config, 'algo', {}), consume(config, 'pca', {})
|
|
||||||
assert config == {}
|
|
||||||
|
|
||||||
env = DummyVecEnv([make_env_func(envC)])
|
|
||||||
if consume(videoC, 'enable', False):
|
|
||||||
env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
|
|
||||||
|
|
||||||
assert consume(algoC, 'name') == 'PPO'
|
|
||||||
policy_name = consume(algoC, 'policy_name')
|
|
||||||
|
|
||||||
total_timesteps = config.get('run', {}).get('total_timesteps', {})
|
|
||||||
|
|
||||||
model = PPO(policy_name, env, **algoC)
|
|
||||||
|
|
||||||
if consume(pcaC, 'enable', False):
|
|
||||||
model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
|
|
||||||
|
|
||||||
model.learn(
|
|
||||||
total_timesteps=total_timesteps,
|
|
||||||
callback=WandbCallback()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_env_func(env_conf):
|
|
||||||
conf = copy.deepcopy(env_conf)
|
|
||||||
name = consume(conf, 'name')
|
|
||||||
legacy_fancy = consume(conf, 'legacy_fancy', False)
|
|
||||||
wrappers = consume(conf, 'wrappers', [])
|
|
||||||
|
|
||||||
def func():
|
|
||||||
if legacy_fancy: # TODO: Remove when no longer needed
|
|
||||||
fancy_gym.make(name, **conf)
|
|
||||||
else:
|
|
||||||
env = gym.make(name, **conf)
|
|
||||||
|
|
||||||
# TODO: Implement wrappers
|
|
||||||
|
|
||||||
env = Monitor(env)
|
|
||||||
return env
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
Runners = {
|
|
||||||
'sb3': sb3_runner,
|
|
||||||
'debug': debug_runner
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
12
setup.py
Normal file
12
setup.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='slate',
|
||||||
|
version='1.0.0',
|
||||||
|
# url='https://github.com/mypackage.git',
|
||||||
|
# author='Author Name',
|
||||||
|
# author_email='author@gmail.com',
|
||||||
|
# description='Description of my package',
|
||||||
|
packages=['.'],
|
||||||
|
install_requires=[],
|
||||||
|
)
|
1
slate/__init__.py
Normal file
1
slate/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from slate import Slate
|
206
slate/slate.py
Normal file
206
slate/slate.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
#import fancy_gym
|
||||||
|
#from stable_baselines3 import PPO
|
||||||
|
#from stable_baselines3.common.monitor import Monitor
|
||||||
|
#from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||||
|
import wandb
|
||||||
|
from wandb.integration.sb3 import WandbCallback
|
||||||
|
#import gymnasium as gym
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import copy
|
||||||
|
import collections.abc
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
d = pdb.set_trace
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyslurm
|
||||||
|
except ImportError:
|
||||||
|
slurm_avaible = False
|
||||||
|
else:
|
||||||
|
slurm_avaible = True
|
||||||
|
|
||||||
|
# TODO: Implement Slurm
|
||||||
|
# TODO: Implement Parallel
|
||||||
|
# TODO: Implement Testing
|
||||||
|
# TODO: Implement Ablative
|
||||||
|
# TODO: Implement PCA
|
||||||
|
|
||||||
|
|
||||||
|
class Slate():
|
||||||
|
def __init__(self, runners):
|
||||||
|
self.runners = runners
|
||||||
|
|
||||||
|
def load_config(self, filename, name):
|
||||||
|
config, stack = self._load_config(filename, name)
|
||||||
|
print('[i] Merged Configs: ', stack)
|
||||||
|
self.deep_expand_vars(config, config=config)
|
||||||
|
self.consume(config, 'vars', {})
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _load_config(self, filename, name, stack=[]):
|
||||||
|
stack.append(f'{filename}:{name}')
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
docs = yaml.safe_load_all(f)
|
||||||
|
for doc in docs:
|
||||||
|
if 'name' in doc:
|
||||||
|
if doc['name'] == name:
|
||||||
|
if 'import' in doc:
|
||||||
|
imports = doc['import'].split(',')
|
||||||
|
del doc['import']
|
||||||
|
for imp in imports:
|
||||||
|
if imp[0] == ' ':
|
||||||
|
imp = imp[1:]
|
||||||
|
if imp == "$":
|
||||||
|
imp = ':DEFAULT'
|
||||||
|
rel_path, *opt = imp.split(':')
|
||||||
|
if len(opt) == 0:
|
||||||
|
nested_name = 'DEFAULT'
|
||||||
|
elif len(opt) == 1:
|
||||||
|
nested_name = opt[0]
|
||||||
|
else:
|
||||||
|
raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
|
||||||
|
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
|
||||||
|
child, stack = self._load_config(nested_path, nested_name, stack=stack)
|
||||||
|
doc = self.deep_update(child, doc)
|
||||||
|
return doc, stack
|
||||||
|
raise Exception(f'Unable to find experiment <{name}> in <{filename}>')
|
||||||
|
|
||||||
|
def deep_update(self, d, u):
|
||||||
|
for kstr, v in u.items():
|
||||||
|
ks = kstr.split('.')
|
||||||
|
head = d
|
||||||
|
for k in ks:
|
||||||
|
last_head = head
|
||||||
|
if k not in head:
|
||||||
|
head[k] = {}
|
||||||
|
head = head[k]
|
||||||
|
if isinstance(v, collections.abc.Mapping):
|
||||||
|
last_head[ks[-1]] = self.deep_update(d.get(k, {}), v)
|
||||||
|
else:
|
||||||
|
last_head[ks[-1]] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
def expand_vars(self, string, **kwargs):
|
||||||
|
if isinstance(string, str):
|
||||||
|
return string.format(**kwargs)
|
||||||
|
return string
|
||||||
|
|
||||||
|
def apply_nested(self, d, f):
|
||||||
|
for k, v in d.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
self.apply_nested(v, f)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
for i, e in enumerate(v):
|
||||||
|
self.apply_nested({'PTR': d[k][i]}, f)
|
||||||
|
else:
|
||||||
|
d[k] = f(v)
|
||||||
|
|
||||||
|
def deep_expand_vars(self, dict, **kwargs):
|
||||||
|
self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))
|
||||||
|
|
||||||
|
def consume(self, conf, key, default=None):
|
||||||
|
keys_arr = key.split('.')
|
||||||
|
if len(keys_arr) == 1:
|
||||||
|
k = keys_arr[0]
|
||||||
|
if default != None:
|
||||||
|
val = conf.get(k, default)
|
||||||
|
else:
|
||||||
|
val = conf[k]
|
||||||
|
if k in conf:
|
||||||
|
del conf[k]
|
||||||
|
return val
|
||||||
|
child = conf.get(keys_arr[0], {})
|
||||||
|
child_keys = '.'.join(keys_arr[1:])
|
||||||
|
return self.consume(child, child_keys, default=default)
|
||||||
|
|
||||||
|
def run_local(self, filename, name, job_num=None):
|
||||||
|
config = self.load_config(filename, name)
|
||||||
|
if self.consume(config, 'sweep.enable', False):
|
||||||
|
sweepC = self.consume(config, 'sweep')
|
||||||
|
project = self.consume(config, 'wandb.project')
|
||||||
|
sweep_id = wandb.sweep(
|
||||||
|
sweep=sweepC,
|
||||||
|
project=project
|
||||||
|
)
|
||||||
|
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
|
||||||
|
wandb.agent(sweep_id, function=partial(self._run_from_sweep, config, runnerName, project, wandbC), count=config['run']['reps_per_agent'])
|
||||||
|
else:
|
||||||
|
self.consume(config, 'sweep', {})
|
||||||
|
self.run_single(config)
|
||||||
|
|
||||||
|
def _run_from_sweep(self, orig_config, runnerName, project, wandbC):
|
||||||
|
runner = self.runners[runnerName]
|
||||||
|
|
||||||
|
with wandb.init(
|
||||||
|
project=project,
|
||||||
|
**wandbC
|
||||||
|
) as run:
|
||||||
|
config = copy.deepcopy(orig_config)
|
||||||
|
self.deep_update(config, wandb.config)
|
||||||
|
runner(run, config)
|
||||||
|
|
||||||
|
assert config == {}, ('Config was not completely consumed: ', config)
|
||||||
|
|
||||||
|
def run_slurm(self, filename, name):
|
||||||
|
assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
|
||||||
|
config = self.load_config(filename, name)
|
||||||
|
slurmC = self.consume(config, 'slurm')
|
||||||
|
s_name = self.consume(slurmC, 'name')
|
||||||
|
|
||||||
|
python_script = 'main.py'
|
||||||
|
sh_lines = self.consume(slurmC, 'sh_lines', [])
|
||||||
|
if venv := self.consume(slurmC, 'venv', False):
|
||||||
|
sh_lines += [f'source activate {venv}']
|
||||||
|
sh_lines += [f'python3 {python_script} {filename} {name} -j $SLURM_ARRAY_TASK_ID']
|
||||||
|
script = " && ".join(sh_lines)
|
||||||
|
|
||||||
|
num_jobs = 1
|
||||||
|
|
||||||
|
last_job_idx = num_jobs - 1
|
||||||
|
num_parallel_jobs = min(self.consume(config, 'slurm.num_parallel_jobs', num_jobs), num_jobs)
|
||||||
|
array = f'0-{last_job_idx}%{num_parallel_jobs}'
|
||||||
|
job = pyslurm.JobSubmitDescription(s_name, script=script, array=array, **config['slurm'])
|
||||||
|
job_id = job.submit()
|
||||||
|
print(f'[i] Job submitted to slurm with id {job_id}')
|
||||||
|
|
||||||
|
def run_single(self, config):
|
||||||
|
runnerName, wandbC = self.consume(config, 'runner'), self.consume(config, 'wandb', {})
|
||||||
|
|
||||||
|
runner = Runners[runnerName]
|
||||||
|
|
||||||
|
with wandb.init(
|
||||||
|
project=self.consume(wandbC, 'project'),
|
||||||
|
config=config,
|
||||||
|
**wandbC
|
||||||
|
) as run:
|
||||||
|
runner(run, config)
|
||||||
|
|
||||||
|
assert config == {}, ('Config was not completely consumed: ', config)
|
||||||
|
|
||||||
|
def from_args(self):
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("config_file", nargs='?', default=None)
|
||||||
|
parser.add_argument("experiment", nargs='?', default='DEFAULT')
|
||||||
|
parser.add_argument("-s", "--slurm", action="store_true")
|
||||||
|
parser.add_argument("-w", "--worker", action="store_true")
|
||||||
|
parser.add_argument("-j", "--job_num", default=None)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.worker:
|
||||||
|
raise Exception('Not yet implemented')
|
||||||
|
|
||||||
|
assert args.config_file != None, 'Need to supply config file.'
|
||||||
|
if args.slurm:
|
||||||
|
self.run_slurm(args.config_file, args.experiment)
|
||||||
|
else:
|
||||||
|
self.run_local(args.config_file, args.experiment, args.job_num)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
raise Exception('You are using it wrong...')
|
Loading…
Reference in New Issue
Block a user