Initial commit
This commit is contained in:
commit
2d2f8f9f6e
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
__pycache__
|
||||||
|
.venv
|
77
config.yaml
Normal file
77
config.yaml
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
name: DEFAULT
|
||||||
|
project: alpha
|
||||||
|
|
||||||
|
slurm:
|
||||||
|
name: False
|
||||||
|
partition: "single"
|
||||||
|
num_parallel_jobs: 64
|
||||||
|
cpus-per-task: 1
|
||||||
|
mem-per-cpu: 3000
|
||||||
|
time: 1440 # in minutes
|
||||||
|
|
||||||
|
repetitions: 3
|
||||||
|
agents_per_job: 3
|
||||||
|
reps_per_agent: 1
|
||||||
|
|
||||||
|
total_timesteps: 10000
|
||||||
|
|
||||||
|
video:
|
||||||
|
enable: True
|
||||||
|
length: 3000
|
||||||
|
frequency: 100
|
||||||
|
|
||||||
|
test:
|
||||||
|
enable: True
|
||||||
|
length: 3000
|
||||||
|
frequency: 100 # 32 # 10
|
||||||
|
deterministic: Both
|
||||||
|
num_envs: 1
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: BoxPushingDense-v0
|
||||||
|
legacy_fancy: True
|
||||||
|
normalize_obs: True
|
||||||
|
normalize_rew: True
|
||||||
|
num_envs: 1
|
||||||
|
env_args:
|
||||||
|
more_obs:True
|
||||||
|
|
||||||
|
algo:
|
||||||
|
name: PPO
|
||||||
|
policy_name: MlpPolicy
|
||||||
|
n_steps: 4096
|
||||||
|
vf_coef: 1.0e-5
|
||||||
|
learning_rate: 5.0e-5
|
||||||
|
batch_size: 512
|
||||||
|
action_coef: 0
|
||||||
|
ent_coef: 0
|
||||||
|
normalize_advantage: False # True
|
||||||
|
|
||||||
|
pca:
|
||||||
|
enable: False
|
||||||
|
window: 64
|
||||||
|
skip_conditioning: True
|
||||||
|
Base_Noise: WHITE
|
||||||
|
init_std: 1.0
|
||||||
|
---
|
||||||
|
sweep:
|
||||||
|
enable: True
|
||||||
|
method: random,
|
||||||
|
metric:
|
||||||
|
goal: minimize,
|
||||||
|
name: score
|
||||||
|
parameters:
|
||||||
|
lel: lol
|
||||||
|
---
|
||||||
|
ablative:
|
||||||
|
task:
|
||||||
|
add_time_awareness: [True]
|
||||||
|
add_normalize_obs: [False]
|
||||||
|
env_args:
|
||||||
|
more_obs: [True]
|
||||||
|
algorithm:
|
||||||
|
network:
|
||||||
|
#ent_coef: [0, 0.001, 0.003]
|
||||||
|
normalize_advantage: [True]
|
||||||
|
distribution:
|
||||||
|
init_std: [0.5]
|
111
main.py
Normal file
111
main.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import fancy_gym
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
||||||
|
import wandb
|
||||||
|
from wandb.integration.sb3 import WandbCallback
|
||||||
|
import gymnasium as gym
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
import collections.abc
|
||||||
|
|
||||||
|
PCA = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(filename, name):
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
docs = yaml.safe_load_all(f)
|
||||||
|
for doc in docs:
|
||||||
|
if 'name' in doc:
|
||||||
|
if doc['name'] == name:
|
||||||
|
if 'import' in doc:
|
||||||
|
imports = reversed(doc['import'].split(','))
|
||||||
|
del doc['import']
|
||||||
|
for imp in imports:
|
||||||
|
rel_path, *opt = imp.split(':')
|
||||||
|
if len(opt) == 0:
|
||||||
|
nested_name = 'DEFAULT'
|
||||||
|
elif len(opt) == 1:
|
||||||
|
nested_name = opt[0]
|
||||||
|
else:
|
||||||
|
raise Exception()
|
||||||
|
nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
|
||||||
|
child = load_config(nested_path, nested_name)
|
||||||
|
doc = deep_update(child, doc)
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
def deep_update(d, u):
|
||||||
|
for k, v in u.items():
|
||||||
|
if isinstance(v, collections.abc.Mapping):
|
||||||
|
d[k] = deep_update(d.get(k, {}), v)
|
||||||
|
else:
|
||||||
|
d[k] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def run(filename, name):
|
||||||
|
config = load_config(filename, name)
|
||||||
|
if 'sweep' in config and config['sweep']['enable']:
|
||||||
|
sweepC = config['sweep']
|
||||||
|
del sweepC['enable']
|
||||||
|
sweep_id = wandb.sweep(
|
||||||
|
sweep=sweepC,
|
||||||
|
project=config['project']
|
||||||
|
)
|
||||||
|
wandb.agent(sweep_id, function=run_single, count=config['reps_per_agent'])
|
||||||
|
else:
|
||||||
|
run_single(config)
|
||||||
|
|
||||||
|
|
||||||
|
def run_single(config):
|
||||||
|
videoC, testC, envC, algoC, pcaC = config.get('video', {}), config.get('test', {}), config.get('env', {}), config.get('algo', {}), config.get('pca', {})
|
||||||
|
|
||||||
|
with wandb.init(
|
||||||
|
project=config['project'],
|
||||||
|
config=config,
|
||||||
|
sync_tensorboard=True,
|
||||||
|
monitor_gym=True,
|
||||||
|
save_code=True,
|
||||||
|
) as run:
|
||||||
|
env = DummyVecEnv([make_env_func(envC)])
|
||||||
|
if videoC.get('enable', False):
|
||||||
|
env = VecVideoRecorder(env, f"videos/{run.id}", record_video_trigger=lambda x: x % videoC['frequency'] == 0, video_length=videoC['length'])
|
||||||
|
|
||||||
|
assert algoC['name'] == 'PPO'
|
||||||
|
del algoC['name']
|
||||||
|
policy_name = algoC['policy_name']
|
||||||
|
del algoC['policy_name']
|
||||||
|
model = PPO(policy_name env, **algo)
|
||||||
|
|
||||||
|
if pcaC.get('enable', False):
|
||||||
|
del pcaC['enable']
|
||||||
|
model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)
|
||||||
|
|
||||||
|
model.learn(
|
||||||
|
total_timesteps=config["total_timesteps"],
|
||||||
|
callback=WandbCallback()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_env_func(env_conf):
|
||||||
|
def func():
|
||||||
|
legacy_fancy = env_conf.get('legacy_fancy', False)
|
||||||
|
del env_conf['name']
|
||||||
|
if 'legacy_fancy' in env_conf:
|
||||||
|
del env_conf['legacy_fancy']
|
||||||
|
if legacy_fancy: # TODO: Remove when no longer needed
|
||||||
|
fancy_gym.make(env_conf['name'], **env_conf)
|
||||||
|
else:
|
||||||
|
env = gym.make(env_conf['name'], **env_conf)
|
||||||
|
env = Monitor(env)
|
||||||
|
return env
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user