Initial commit

Dominik Moritz Roth 2023-07-05 15:02:53 +02:00
commit 2d2f8f9f6e
3 changed files with 190 additions and 0 deletions

.gitignore vendored Normal file (+2)

@@ -0,0 +1,2 @@
__pycache__
.venv

config.yaml Normal file (+77)

@@ -0,0 +1,77 @@
name: DEFAULT
project: alpha
slurm:
  name: False
  partition: "single"
  num_parallel_jobs: 64
  cpus-per-task: 1
  mem-per-cpu: 3000
  time: 1440 # in minutes
repetitions: 3
agents_per_job: 3
reps_per_agent: 1
total_timesteps: 10000
video:
  enable: True
  length: 3000
  frequency: 100
test:
  enable: True
  length: 3000
  frequency: 100 # 32 # 10
  deterministic: Both
  num_envs: 1
env:
  name: BoxPushingDense-v0
  legacy_fancy: True
  normalize_obs: True
  normalize_rew: True
  num_envs: 1
  env_args:
    more_obs: True
algo:
  name: PPO
  policy_name: MlpPolicy
  n_steps: 4096
  vf_coef: 1.0e-5
  learning_rate: 5.0e-5
  batch_size: 512
  action_coef: 0
  ent_coef: 0
  normalize_advantage: False # True
pca:
  enable: False
  window: 64
  skip_conditioning: True
  Base_Noise: WHITE
  init_std: 1.0
---
sweep:
  enable: True
  method: random
  metric:
    goal: minimize
    name: score
  parameters:
    lel: lol
---
ablative:
  task:
    add_time_awareness: [True]
    add_normalize_obs: [False]
    env_args:
      more_obs: [True]
  algorithm:
    network:
      #ent_coef: [0, 0.001, 0.003]
      normalize_advantage: [True]
    distribution:
      init_std: [0.5]
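
config.yaml is one file holding three `---`-separated YAML documents: the DEFAULT run config, a W&B sweep spec, and an ablative-study spec. As a minimal sketch (plain PyYAML, mirroring what load_config in main.py below does), it can be read like this:

    import yaml

    with open('config.yaml') as f:
        # safe_load_all yields one dict per ---separated document
        docs = list(yaml.safe_load_all(f))
    default = next(d for d in docs if d.get('name') == 'DEFAULT')
    print(default['algo']['learning_rate'])  # 5e-05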

main.py Normal file (+111)

@@ -0,0 +1,111 @@
import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
import wandb
from wandb.integration.sb3 import WandbCallback
import gymnasium as gym
import yaml
import os
import collections.abc

PCA = None  # Placeholder: the PCA action distribution is not implemented yet.

def load_config(filename, name):
    with open(filename, 'r') as f:
        docs = yaml.safe_load_all(f)
        for doc in docs:
            if 'name' in doc and doc['name'] == name:
                if 'import' in doc:
                    imports = reversed(doc['import'].split(','))
                    del doc['import']
                    for imp in imports:
                        rel_path, *opt = imp.split(':')
                        if len(opt) == 0:
                            nested_name = 'DEFAULT'
                        elif len(opt) == 1:
                            nested_name = opt[0]
                        else:
                            raise ValueError(f'Malformed import spec: {imp}')
                        if len(rel_path):
                            nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        else:
                            nested_path = filename  # ':NAME' with no path imports from the current file
                        child = load_config(nested_path, nested_name)
                        doc = deep_update(child, doc)
                return doc
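
# Import resolution (follows from the loop above): a doc such as
#   name: EXPERIMENT
#   import: base.yaml,:OTHER
# is merged, lowest precedence first, from base.yaml's DEFAULT doc, then the
# OTHER doc of the same file, then the importing doc itself; each step is a
# deep_update(child, doc), so later sources override earlier ones.
# (base.yaml and OTHER are hypothetical names for illustration.)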

def deep_update(d, u):
    # Recursively merge mapping u into d; values from u win on conflicts.
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = deep_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d
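
# Example of the merge semantics:
#   deep_update({'algo': {'lr': 1e-4, 'n_steps': 4096}}, {'algo': {'lr': 5e-5}})
#   returns {'algo': {'lr': 5e-5, 'n_steps': 4096}}: nested dicts are merged
#   key-by-key instead of being replaced wholesale.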

def run(filename, name):
    config = load_config(filename, name)
    if 'sweep' in config and config['sweep']['enable']:
        sweepC = config['sweep']
        del sweepC['enable']
        sweep_id = wandb.sweep(
            sweep=sweepC,
            project=config['project']
        )
        # wandb.agent invokes its function with no arguments, so bind config here.
        wandb.agent(sweep_id, function=lambda: run_single(config), count=config['reps_per_agent'])
    else:
        run_single(config)
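
# Note: sweepC is handed to wandb.sweep verbatim, so it must follow the W&B
# sweep spec (method / metric / parameters); the `parameters` block in
# config.yaml above is only a placeholder. A real block could look like:
#   parameters:
#     learning_rate:
#       values: [1.0e-5, 5.0e-5, 1.0e-4]
# (learning_rate and its values are illustrative, not from this commit.)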

def run_single(config):
    videoC, testC, envC, algoC, pcaC = (config.get('video', {}), config.get('test', {}),
                                        config.get('env', {}), config.get('algo', {}),
                                        config.get('pca', {}))  # testC is parsed but not used yet
    with wandb.init(
        project=config['project'],
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    ) as run:
        env = DummyVecEnv([make_env_func(envC)])
        if videoC.get('enable', False):
            env = VecVideoRecorder(env, f"videos/{run.id}",
                                   record_video_trigger=lambda x: x % videoC['frequency'] == 0,
                                   video_length=videoC['length'])
        assert algoC['name'] == 'PPO'
        del algoC['name']
        policy_name = algoC['policy_name']
        del algoC['policy_name']
        model = PPO(policy_name, env, **algoC)
        if pcaC.get('enable', False):
            # Note: PCA is still a stub (None above); enabling pca would fail here.
            del pcaC['enable']
            model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=WandbCallback()
        )

def make_env_func(env_conf):
    def func():
        # Read values before mutating env_conf so the remaining keys can be
        # passed through as make() kwargs.
        name = env_conf.pop('name')
        legacy_fancy = env_conf.pop('legacy_fancy', False)
        if legacy_fancy:  # TODO: Remove when no longer needed
            env = fancy_gym.make(name, **env_conf)
        else:
            env = gym.make(name, **env_conf)
        env = Monitor(env)
        return env
    return func
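
# Note: DummyVecEnv (used in run_single) expects a list of zero-argument
# callables that each construct one environment, so make_env_func returns a
# thunk rather than a ready env.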

def main():
    # Defaults taken from this commit: the shipped config file and its DEFAULT doc.
    run('config.yaml', 'DEFAULT')

if __name__ == '__main__':
    main()