from slate import Slate

import fancy_gym
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
import gymnasium as gym
import copy

# Placeholder for a PCA action-distribution class. ``sb3_runner`` calls
# ``PCA(action_shape, **pca_config)`` and assigns the result to
# ``model.policy.action_dist`` when the ``pca.enable`` option is true, so this
# must be bound to a real implementation before that option is used.
PCA = None


def debug_runner(slate, run, config):
    """Dump the config, consume all of it in place, then idle for 10 seconds.

    ``slate`` and ``run`` are accepted only to match the runner signature and
    are not used.
    """
    print(config)

    # Snapshot the keys first: deleting while iterating the dict is illegal.
    remaining_keys = [*config]
    for key in remaining_keys:
        del config[key]

    import time
    time.sleep(10)


def sb3_runner(slate, run, config):
    """Train a Stable-Baselines3 PPO agent described by a Slate config.

    The ``video``, ``test``, ``env``, ``algo``, ``pca`` and ``run`` sections
    are consumed from ``config``. Any other top-level key is rejected, so
    typos in the config do not go unnoticed.

    Args:
        slate: the ``Slate`` instance; only its ``consume`` helper is used.
        run: the active run object; ``run.id`` names the video directory.
        config: the run configuration dict, consumed in place.

    Raises:
        ValueError: if ``config`` contains unknown top-level keys or
            ``algo.name`` is not ``'PPO'``.
        NotImplementedError: if ``pca.enable`` is set but the module-level
            ``PCA`` placeholder has not been bound to an implementation.
    """
    videoC = slate.consume(config, 'video', {})
    testC = slate.consume(config, 'test', {})  # consumed so it is not flagged; unused for now
    envC = slate.consume(config, 'env', {})
    algoC = slate.consume(config, 'algo', {})
    pcaC = slate.consume(config, 'pca', {})
    # Must be consumed *before* the leftover check: reading it afterwards
    # always yielded an empty dict and an invalid ``total_timesteps``.
    runC = slate.consume(config, 'run', {})
    if config:
        raise ValueError(f'Unconsumed config keys: {list(config)}')

    total_timesteps = slate.consume(runC, 'total_timesteps')

    # Consume outside of any ``assert``: under ``python -O`` asserts are
    # stripped, the key would remain in ``algoC`` and be forwarded to PPO.
    algo_name = slate.consume(algoC, 'name')
    if algo_name != 'PPO':
        raise ValueError(f"Unsupported algorithm {algo_name!r}; only 'PPO' is implemented")
    policy_name = slate.consume(algoC, 'policy_name')

    use_pca = slate.consume(pcaC, 'enable', False)
    if use_pca and PCA is None:
        # Fail before building envs instead of a late "NoneType is not callable".
        raise NotImplementedError("pca.enable is set but no PCA implementation is available")

    env = DummyVecEnv([make_env_func(slate, envC)])
    if slate.consume(videoC, 'enable', False):
        video_frequency = videoC['frequency']
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda step: step % video_frequency == 0,
            video_length=videoC['length'],
        )

    # Every remaining ``algo`` key is forwarded verbatim to the PPO constructor.
    model = PPO(policy_name, env, **algoC)

    if use_pca:
        # Remaining ``pca`` keys are forwarded to the PCA constructor.
        model.policy.action_dist = PCA(model.policy.action_space.shape, **pcaC)

    model.learn(
        total_timesteps=total_timesteps,
        callback=WandbCallback()
    )


def make_env_func(slate, env_conf):
    """Return a zero-argument factory that builds one monitored environment.

    ``env_conf`` is deep-copied first, so the caller's dict is left untouched.
    The keys consumed from the copy are:

    - ``name``: passed to ``consume`` with no default.
    - ``legacy_fancy``: defaults to ``False``.
    - ``wrappers``: defaults to ``[]``.

    Every remaining key is forwarded as a keyword argument to
    ``fancy_gym.make`` (when ``legacy_fancy`` is true) or ``gym.make``.

    Args:
        slate: the ``Slate`` instance; only its ``consume`` helper is used.
        env_conf: environment section of the run config.

    Returns:
        A callable (as passed to ``DummyVecEnv``) that creates a new
        ``Monitor``-wrapped environment on each call.
    """
    conf = copy.deepcopy(env_conf)
    name = slate.consume(conf, 'name')
    legacy_fancy = slate.consume(conf, 'legacy_fancy', False)
    # Consumed so it is not forwarded to ``make``; not applied yet (see TODO below).
    wrappers = slate.consume(conf, 'wrappers', [])

    def func():
        if legacy_fancy:  # TODO: Remove when no longer needed
            # The result must be bound; previously it was discarded, which
            # made ``Monitor(env)`` below raise UnboundLocalError.
            env = fancy_gym.make(name, **conf)
        else:
            env = gym.make(name, **conf)

        # TODO: Implement wrappers

        env = Monitor(env)
        return env
    return func


# Runner registry handed to ``Slate``: each key names a runner entry point.
runners = dict(
    sb3=sb3_runner,
    debug=debug_runner,
)

if __name__ == '__main__':
    # Build the Slate driver around the registry and start it via
    # ``from_args`` (presumably parses command-line arguments; see Slate docs).
    slate = Slate(runners)
    slate.from_args()