commit 2d2f8f9f6e9786c9c19b307044f40b28a207f63e
Author: Dominik Roth
Date:   Wed Jul 5 15:02:53 2023 +0200

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9f7550b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+.venv
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..4b4179c
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,77 @@
+name: DEFAULT
+project: alpha
+
+slurm:
+  name: False
+  partition: "single"
+  num_parallel_jobs: 64
+  cpus-per-task: 1
+  mem-per-cpu: 3000
+  time: 1440 # in minutes
+
+repetitions: 3
+agents_per_job: 3
+reps_per_agent: 1
+
+total_timesteps: 10000
+
+video:
+  enable: True
+  length: 3000
+  frequency: 100
+
+test:
+  enable: True
+  length: 3000
+  frequency: 100 # 32 # 10
+  deterministic: Both
+  num_envs: 1
+
+env:
+  name: BoxPushingDense-v0
+  legacy_fancy: True
+  normalize_obs: True
+  normalize_rew: True
+  num_envs: 1
+  env_args:
+    more_obs: True
+
+algo:
+  name: PPO
+  policy_name: MlpPolicy
+  n_steps: 4096
+  vf_coef: 1.0e-5
+  learning_rate: 5.0e-5
+  batch_size: 512
+  action_coef: 0
+  ent_coef: 0
+  normalize_advantage: False # True
+
+pca:
+  enable: False
+  window: 64
+  skip_conditioning: True
+  Base_Noise: WHITE
+  init_std: 1.0
+---
+sweep:
+  enable: True
+  method: random
+  metric:
+    goal: minimize
+    name: score
+  parameters:
+    lel: lol
+---
+ablative:
+  task:
+    add_time_awareness: [True]
+    add_normalize_obs: [False]
+    env_args:
+      more_obs: [True]
+  algorithm:
+    network:
+      #ent_coef: [0, 0.001, 0.003]
+      normalize_advantage: [True]
+    distribution:
+      init_std: [0.5]
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..7ac67cd
--- /dev/null
+++ b/main.py
@@ -0,0 +1,113 @@
+import sys
+import fancy_gym
+from stable_baselines3 import PPO
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
+import wandb
+from wandb.integration.sb3 import WandbCallback
+import gymnasium as gym
+import yaml
+import os
+import collections.abc
+
+PCA = None  # placeholder: the PCA action distribution is not implemented yet
+
+
+def load_config(filename, name):
+    with open(filename, 'r') as f:
+        docs = yaml.safe_load_all(f)
+        for doc in docs:
+            if 'name' in doc and doc['name'] == name:
+                if 'import' in doc:
+                    imports = reversed(doc['import'].split(','))
+                    del doc['import']
+                    for imp in imports:
+                        rel_path, *opt = imp.split(':')
+                        if len(opt) == 0:
+                            nested_name = 'DEFAULT'
+                        elif len(opt) == 1:
+                            nested_name = opt[0]
+                        else:
+                            raise ValueError(f'invalid import spec: {imp}')
+                        nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
+                        child = load_config(nested_path, nested_name)
+                        doc = deep_update(child, doc)
+                return doc
+    raise KeyError(f'no config named {name} in {filename}')
+
+
+def deep_update(d, u):
+    for k, v in u.items():
+        if isinstance(v, collections.abc.Mapping):
+            d[k] = deep_update(d.get(k, {}), v)
+        else:
+            d[k] = v
+    return d
+
+
+def run(filename, name):
+    config = load_config(filename, name)
+    if 'sweep' in config and config['sweep']['enable']:
+        sweepC = config['sweep']
+        del sweepC['enable']
+        sweep_id = wandb.sweep(
+            sweep=sweepC,
+            project=config['project']
+        )
+        # wandb.agent calls its function with no arguments, so bind config here
+        wandb.agent(sweep_id, function=lambda: run_single(config), count=config['reps_per_agent'])
+    else:
+        run_single(config)
+
+
+def run_single(config):
+    videoC, testC, envC, algoC, pcaC = config.get('video', {}), config.get('test', {}), config.get('env', {}), config.get('algo', {}), config.get('pca', {})
+
+    with wandb.init(
+        project=config['project'],
+        config=config,
+        sync_tensorboard=True,
+        monitor_gym=True,
+        save_code=True,
+    ) as run:
+        env = DummyVecEnv([make_env_func(envC)])
+        if videoC.get('enable', False):
+            env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
+
+        assert algoC['name'] == 'PPO'
+        del algoC['name']
+        policy_name = algoC['policy_name']
+        del algoC['policy_name']
+        model = PPO(policy_name, env, **algoC)
+
+        if pcaC.get('enable', False):
+            del pcaC['enable']
+            model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
+
+        model.learn(
+            total_timesteps=config["total_timesteps"],
+            callback=WandbCallback()
+        )
+
+
+def make_env_func(env_conf):
+    def func():
+        name = env_conf.pop('name')  # pop before forwarding the rest as kwargs
+        legacy_fancy = env_conf.pop('legacy_fancy', False)
+        if legacy_fancy:  # TODO: Remove when no longer needed
+            env = fancy_gym.make(name, **env_conf)
+        else:
+            env = gym.make(name, **env_conf)
+        env = Monitor(env)
+        return env
+    return func
+
+
+def main():
+    filename = sys.argv[1] if len(sys.argv) > 1 else 'config.yaml'
+    name = sys.argv[2] if len(sys.argv) > 2 else 'DEFAULT'
+    run(filename, name)
+
+
+if __name__ == '__main__':
+    main()
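Note: load_config supports an import key that no config in this commit exercises yet. A document may list comma-separated entries of the form path:name; each is resolved recursively and merged underneath via deep_update, so the importing document's own keys win on conflict. A hypothetical second experiment (the BIG_BATCH name and override value below are made up for illustration) would read:

    name: BIG_BATCH     # run with: python main.py config.yaml BIG_BATCH
    import: ":DEFAULT"  # empty path before the colon -> same file; DEFAULT names the base document
    algo:
      batch_size: 2048  # overrides DEFAULT's 512; all other keys are inherited

An entry such as other.yaml:NAME would instead load the NAME document from a file resolved relative to the importing file's directory, and omitting the colon entirely imports that file's DEFAULT document.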