# Slate/example.py
#
# 74 lines
# 2.1 KiB
# Python

from slate import Slate
import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
import gymnasium as gym
import copy
# Placeholder for the callable that sb3_runner assigns to
# model.policy.action_dist when the 'pca' config section is enabled.
# No implementation is provided in this file, so that path cannot work
# while this stays None.
PCA = None
def debug_runner(slate, run, config):
    """Print the received config, empty it in place, then sleep for 10 seconds.

    Args:
        slate: Unused.
        run: Unused.
        config: Mutable mapping; printed and then emptied key by key.
    """
    import time

    print(config)
    # Snapshot the keys first so the mapping can be mutated while looping.
    for key in tuple(config.keys()):
        del config[key]
    time.sleep(10)
def sb3_runner(slate, run, config):
    """Train a Stable-Baselines3 PPO agent as described by ``config``.

    The sections ``video``, ``test``, ``env``, ``algo``, ``pca`` and ``run``
    are consumed from ``config``; any key left over afterwards is rejected.

    Args:
        slate: Object providing ``consume(mapping, key[, default])``.
        run: Run handle; ``run.id`` names the video output directory.
        config: Mutable config mapping; its known sections are consumed
            in place.

    Raises:
        ValueError: If ``config`` has leftover keys, if ``algo.name`` is not
            ``'PPO'``, or if ``run.total_timesteps`` is missing.
        NotImplementedError: If the ``pca`` section is enabled while the
            module-level ``PCA`` is still ``None``.
        KeyError: If video recording is enabled without ``frequency`` or
            ``length``.
    """
    video_conf = slate.consume(config, 'video', {})
    # The 'test' section is consumed so it does not count as a leftover key;
    # it is not used by this runner yet.
    slate.consume(config, 'test', {})
    env_conf = slate.consume(config, 'env', {})
    algo_conf = slate.consume(config, 'algo', {})
    pca_conf = slate.consume(config, 'pca', {})
    # 'run' has to be consumed *before* the leftover check: reading it from
    # ``config`` afterwards can only ever see an empty mapping.
    run_conf = slate.consume(config, 'run', {})
    if config != {}:
        raise ValueError(f'Unexpected config keys: {list(config)}')

    # Validate everything that is cheap to check before building the env.
    # The consume() call is kept out of any assert so it also runs under -O;
    # otherwise 'name' would be forwarded to PPO(**algo_conf).
    algo_name = slate.consume(algo_conf, 'name')
    if algo_name != 'PPO':
        raise ValueError(f"Unsupported algo {algo_name!r}; only 'PPO' is supported")
    policy_name = slate.consume(algo_conf, 'policy_name')

    total_timesteps = run_conf.get('total_timesteps')
    if total_timesteps is None:
        raise ValueError("Config is missing 'run.total_timesteps'")

    use_pca = slate.consume(pca_conf, 'enable', False)
    if use_pca and PCA is None:
        raise NotImplementedError('PCA is enabled in the config but no PCA implementation is available')

    env = DummyVecEnv([make_env_func(slate, env_conf)])
    try:
        if slate.consume(video_conf, 'enable', False):
            # Resolve both settings once, so a missing key fails here rather
            # than inside the trigger callback during training.
            frequency = video_conf['frequency']
            length = video_conf['length']
            env = VecVideoRecorder(
                env,
                f"videos/{run.id}",
                record_video_trigger=lambda step: step % frequency == 0,
                video_length=length,
            )

        # Every remaining 'algo' entry is forwarded to PPO as a keyword argument.
        model = PPO(policy_name, env, **algo_conf)

        if use_pca:
            # Every remaining 'pca' entry is forwarded to PCA as a keyword argument.
            model.policy.action_dist = PCA(model.policy.action_space.shape, **pca_conf)

        model.learn(
            total_timesteps=total_timesteps,
            callback=WandbCallback(),
        )
    finally:
        # Close the (possibly wrapped) vec env whether or not training succeeded.
        env.close()
def make_env_func(slate, env_conf):
    """Return a zero-argument factory that builds a Monitor-wrapped environment.

    ``env_conf`` is deep-copied first, so the caller's mapping is left
    untouched. ``name``, ``legacy_fancy`` and ``wrappers`` are consumed from
    the copy; every remaining entry is forwarded to ``make`` as a keyword
    argument.

    Args:
        slate: Object providing ``consume(mapping, key[, default])``.
        env_conf: Environment config mapping.

    Returns:
        A callable taking no arguments that creates the environment (via
        ``fancy_gym.make`` when ``legacy_fancy`` is set, ``gym.make``
        otherwise) and wraps it in ``Monitor``.
    """
    conf = copy.deepcopy(env_conf)
    name = slate.consume(conf, 'name')
    legacy_fancy = slate.consume(conf, 'legacy_fancy', False)
    # Consumed so it is not forwarded to make() as a keyword argument.
    # TODO: Implement wrappers
    slate.consume(conf, 'wrappers', [])

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            # The result must be bound to env; discarding it left env
            # undefined for the Monitor wrapper below.
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)
        return Monitor(env)

    return func
# Registry of runner callables keyed by runner name; handed to Slate below.
runners = {'sb3': sb3_runner, 'debug': debug_runner}

if __name__ == '__main__':
    # Script entry point: build Slate with the registry and let it take over
    # (from_args presumably reads the command-line arguments; see Slate).
    slate = Slate(runners)
    slate.from_args()