74 lines
2.1 KiB
Python
74 lines
2.1 KiB
Python
from slate import Slate
|
|
|
|
#import fancy_gym
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.monitor import Monitor
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
|
|
from wandb.integration.sb3 import WandbCallback
|
|
import gymnasium as gym
|
|
import copy
|
|
|
|
# Placeholder for an optional PCA action-distribution class. sb3_runner calls
# PCA(model.policy.action_space.shape, **pca_config) and assigns the result to
# model.policy.action_dist when the config's pca.enable flag is set, so this
# must be bound to a real implementation before that option can be used.
PCA = None
|
|
|
|
|
|
def debug_runner(slate, run, config, *, sleep_seconds=10):
    """Dummy runner for exercising the Slate pipeline without any training.

    Prints the received config, empties it in place (every key is treated as
    consumed, presumably mirroring how real runners consume their sections),
    then sleeps to simulate a job.

    Args:
        slate: Unused; present to match the runner call signature.
        run: Unused; present to match the runner call signature.
        config: Mutable config dict; it is cleared in place.
        sleep_seconds: Seconds to sleep after clearing the config. Defaults
            to 10, the previously hard-coded value.
    """
    import time

    print(config)
    config.clear()
    time.sleep(sleep_seconds)
|
|
|
|
|
|
def sb3_runner(slate, run, config):
    """Train a Stable-Baselines3 PPO agent from a Slate run config.

    All recognised top-level sections are consumed from ``config`` before the
    leftover check, so an unknown or misspelled section fails loudly.

    Args:
        slate: The ``Slate`` instance; used for ``slate.consume``.
        run: Run handle supplied by Slate; only ``run.id`` is read (to name
            the video directory). Presumably a wandb run -- confirm.
        config: Mutable config dict, emptied by this function. Sections:
            ``video``, ``test`` (currently unused), ``env``, ``algo``,
            ``pca`` and ``run`` (which must provide ``total_timesteps``).

    Raises:
        AssertionError: If ``config`` has unrecognised sections or
            ``algo.name`` is not ``'PPO'``.
        NotImplementedError: If ``pca.enable`` is set while the module-level
            ``PCA`` placeholder is still ``None``.
    """
    videoC = slate.consume(config, 'video', {})
    testC = slate.consume(config, 'test', {})  # reserved; not used yet
    envC = slate.consume(config, 'env', {})
    algoC = slate.consume(config, 'algo', {})
    pcaC = slate.consume(config, 'pca', {})
    # 'run' must be consumed BEFORE the emptiness check. It used to be read
    # after the assert, from a dict that was by then guaranteed empty, so
    # total_timesteps was always {} and model.learn() could not work.
    runC = slate.consume(config, 'run', {})
    assert config == {}

    total_timesteps = slate.consume(runC, 'total_timesteps')

    # Validate the algorithm section and PCA availability before allocating
    # any environments.
    assert slate.consume(algoC, 'name') == 'PPO'
    policy_name = slate.consume(algoC, 'policy_name')
    use_pca = slate.consume(pcaC, 'enable', False)
    if use_pca and PCA is None:
        raise NotImplementedError(
            "pca.enable is set, but no PCA implementation is available (PCA is None)"
        )

    env = DummyVecEnv([make_env_func(slate, envC)])
    try:
        if slate.consume(videoC, 'enable', False):
            video_frequency = slate.consume(videoC, 'frequency')
            video_length = slate.consume(videoC, 'length')
            env = VecVideoRecorder(
                env,
                f"videos/{run.id}",
                record_video_trigger=lambda step: step % video_frequency == 0,
                video_length=video_length,
            )

        # Whatever is left in algoC is forwarded verbatim to the PPO constructor.
        model = PPO(policy_name, env, **algoC)

        if use_pca:
            model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)

        model.learn(
            total_timesteps=total_timesteps,
            callback=WandbCallback(),
        )
    finally:
        # Close the (possibly video-wrapped) vec env. This releases the
        # environments and should let VecVideoRecorder finalise any
        # in-progress recording.
        env.close()
|
|
|
|
|
|
def make_env_func(slate, env_conf):
    """Build a zero-argument factory that creates one monitored environment.

    ``env_conf`` is deep-copied first, so the ``slate.consume`` calls below do
    not mutate the caller's dict.

    Args:
        slate: The ``Slate`` instance, used for ``slate.consume``.
        env_conf: Environment config. ``name`` is required; ``legacy_fancy``
            (default False) and ``wrappers`` (default []) are optional. All
            remaining keys are forwarded as keyword arguments to
            ``fancy_gym.make`` / ``gym.make``.

    Returns:
        A callable taking no arguments that returns a new environment wrapped
        in ``Monitor`` (the factory form passed to ``DummyVecEnv``).
    """
    conf = copy.deepcopy(env_conf)
    name = slate.consume(conf, 'name')
    legacy_fancy = slate.consume(conf, 'legacy_fancy', False)
    # Consumed so it is not forwarded to make(); TODO: Implement wrappers.
    wrappers = slate.consume(conf, 'wrappers', [])

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            # Imported lazily: fancy_gym is only needed on this legacy path and
            # is not imported at module level. The result must be bound to
            # `env`; it was previously discarded, leaving `env` unbound.
            import fancy_gym
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)

        # TODO: Implement wrappers

        env = Monitor(env)
        return env

    return func
|
|
|
|
|
|
# Registry of available runner entry points, keyed by runner name (presumably
# the name a Slate config uses to select one -- confirm against Slate).
runners = dict(
    sb3=sb3_runner,
    debug=debug_runner,
)
|
|
|
|
if __name__ == '__main__':
    # Hand the runner registry to Slate and let it take over from the
    # command-line arguments.
    Slate(runners).from_args()
|