"""Runner functions registered with Slate for training and debugging."""

import copy
import time

import gymnasium as gym
from slate import Slate
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

# NOTE: fancy_gym is imported lazily inside make_env_func, because it is only
# needed for environments flagged with 'legacy_fancy'.

# Placeholder for the PCA action-distribution class; sb3_runner refuses to
# enable PCA while this is still None.
PCA = None


def debug_runner(slate, run, config):
    """Print the config, empty it in place, then sleep for ten seconds.

    Args:
        slate: The Slate instance (unused here).
        run: The current run object (unused here).
        config: Mutable config mapping; every key is deleted from it.
    """
    print(config)
    # Delete over a snapshot of the keys so the mapping can shrink safely.
    for key in list(config):
        del config[key]
    time.sleep(10)


def sb3_runner(slate, run, config):
    """Train a Stable-Baselines3 PPO model as described by ``config``.

    The sections 'video', 'test', 'env', 'algo', 'pca' and 'run' are consumed
    from ``config``; any key left over afterwards is rejected.

    Args:
        slate: The Slate instance; its ``consume`` method is used to pop
            values out of the config mappings.
        run: The current run object; ``run.id`` names the video directory.
        config: Mutable config mapping for this run.

    Raises:
        ValueError: If ``config`` has unconsumed keys or the algorithm name
            is not 'PPO'.
        RuntimeError: If PCA is enabled while no PCA implementation is set.
    """
    videoC = slate.consume(config, 'video', {})
    # 'test' is consumed only so that it does not count as a leftover key.
    testC = slate.consume(config, 'test', {})  # noqa: F841
    envC = slate.consume(config, 'env', {})
    algoC = slate.consume(config, 'algo', {})
    pcaC = slate.consume(config, 'pca', {})
    # 'run' has to be taken out *before* the leftover check; reading it from
    # the already-emptied config could never yield a usable step count.
    runC = slate.consume(config, 'run', {})
    if config:
        raise ValueError(f"Unconsumed config keys: {list(config)}")

    env = DummyVecEnv([make_env_func(slate, envC)])
    if slate.consume(videoC, 'enable', False):
        frequency = videoC['frequency']
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda step: step % frequency == 0,
            video_length=videoC['length'],
        )

    algo_name = slate.consume(algoC, 'name')
    if algo_name != 'PPO':
        raise ValueError(f"Unsupported algorithm {algo_name!r}; only 'PPO' is implemented")
    policy_name = slate.consume(algoC, 'policy_name')

    # Required value: consumed without a default, like the other mandatory keys.
    total_timesteps = slate.consume(runC, 'total_timesteps')

    # Everything still left in the 'algo' section is forwarded to PPO.
    model = PPO(policy_name, env, **algoC)

    if slate.consume(pcaC, 'enable', False):
        if PCA is None:
            raise RuntimeError("PCA is enabled in the config, but no PCA implementation is available")
        model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)

    model.learn(
        total_timesteps=total_timesteps,
        callback=WandbCallback(),
    )


def make_env_func(slate, env_conf):
    """Return a zero-argument factory that builds a Monitor-wrapped env.

    ``env_conf`` is deep-copied, so the caller's mapping is left untouched.
    'name', 'legacy_fancy' and 'wrappers' are consumed from the copy; the
    remaining entries are passed to the ``make`` call as keyword arguments.

    Args:
        slate: The Slate instance providing ``consume``.
        env_conf: Environment config mapping.

    Returns:
        A callable without parameters that creates the environment.
    """
    conf = copy.deepcopy(env_conf)
    name = slate.consume(conf, 'name')
    legacy_fancy = slate.consume(conf, 'legacy_fancy', False)
    # Consumed so it is not forwarded to make(); wrappers are not applied yet.
    wrappers = slate.consume(conf, 'wrappers', [])  # noqa: F841

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            # Imported here so fancy_gym is only required for legacy envs.
            import fancy_gym
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)

        # TODO: Implement wrappers

        return Monitor(env)

    return func


runners = {
    'sb3': sb3_runner,
    'debug': debug_runner,
}

if __name__ == '__main__':
    slate = Slate(runners)
    slate.from_args()