import argparse
import collections.abc
import copy
import os

import fancy_gym
import gymnasium as gym
import wandb
import yaml
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

# Placeholder for the action-distribution class that run_single installs on the
# policy when the config's `pca.enable` flag is set. Nothing in this file assigns
# it, so enabling PCA fails until a real class is plugged in here.
PCA = None


def load_config(filename, name):
    """Return the config document called ``name`` from the YAML file ``filename``.

    ``filename`` may hold several YAML documents. The first mapping whose
    ``name`` key equals ``name`` is selected. If it has an ``import`` key, that
    value is a comma-separated string (or a YAML list) of ``path[:doc_name]``
    entries. Each entry is loaded recursively with this function:

    * ``path`` is relative to the directory of ``filename``; an empty path
      means ``filename`` itself.
    * ``doc_name`` defaults to ``'DEFAULT'``.

    The selected document is deep-merged on top of its imports, so its own keys
    win. Among imports, later entries override earlier ones. The ``import`` key
    is removed from the result.

    Args:
        filename: Path to the YAML file.
        name: Value of the ``name`` key of the document to load.

    Returns:
        The merged config as a dict.

    Raises:
        ValueError: If an import entry contains more than one ``:``.
        KeyError: If ``filename`` has no document named ``name``.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        for doc in yaml.safe_load_all(f):
            # Empty documents load as None; skip those and any non-mapping docs.
            if not isinstance(doc, dict) or 'name' not in doc or doc['name'] != name:
                continue
            if 'import' in doc:
                raw_imports = doc.pop('import')
                if isinstance(raw_imports, str):
                    entries = raw_imports.split(',')
                else:
                    # Allow a YAML list of entries; an empty `import:` means none.
                    entries = list(raw_imports or [])
                # Merge in reverse so that later imports override earlier ones
                # and the selected document overrides all of them.
                for entry in reversed(entries):
                    rel_path, *rest = [part.strip() for part in entry.split(':')]
                    if len(rest) > 1:
                        raise ValueError(
                            f"Invalid import entry {entry!r} in {filename}: "
                            "expected 'path' or 'path:name'"
                        )
                    nested_name = rest[0] if rest else 'DEFAULT'
                    nested_path = (
                        os.path.normpath(os.path.join(os.path.dirname(filename), rel_path))
                        if rel_path
                        else filename
                    )
                    child = load_config(nested_path, nested_name)
                    doc = deep_update(child, doc)
            return doc
    # Previously this fell through and returned None, which crashed callers later
    # with an unrelated error; fail here with a clear message instead.
    raise KeyError(f"No config document named {name!r} in {filename}")


def deep_update(d, u):
    """Recursively merge mapping ``u`` into ``d`` in place and return ``d``.

    For every key in ``u``:

    * If its value is a mapping, it is merged recursively into ``d[key]``.
      When ``d`` has no mutable mapping under that key (the key is missing, or
      holds a scalar, list or ``None``), the merge starts from a fresh dict.
    * Any other value overwrites ``d[key]``. It is assigned by reference, not
      copied.

    Args:
        d: Mapping to update; it is modified in place.
        u: Mapping whose entries take precedence.

    Returns:
        ``d`` itself, after the merge.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, collections.abc.MutableMapping):
                # A scalar/None/list cannot receive item assignments; replace it
                # instead of crashing inside the recursive call.
                existing = {}
            d[key] = deep_update(existing, value)
        else:
            d[key] = value
    return d


def run(filename, name):
    """Load config ``name`` from ``filename`` and either train once or run a W&B sweep.

    If the config has a ``sweep`` section whose ``enable`` flag is truthy, the
    rest of that section is registered as a W&B sweep in ``config['project']``.
    A sweep agent then runs ``config['reps_per_agent']`` trials. Otherwise a
    single training run is started.

    Args:
        filename: Path to the YAML config file.
        name: Name of the config document inside ``filename``.
    """
    config = load_config(filename, name)
    sweep_conf = config.get('sweep')
    if sweep_conf and sweep_conf.get('enable'):
        # Build the sweep spec without the local-only `enable` flag, leaving
        # `config` itself untouched.
        sweep_spec = {key: value for key, value in sweep_conf.items() if key != 'enable'}
        sweep_id = wandb.sweep(
            sweep=sweep_spec,
            project=config['project'],
        )
        # wandb.agent invokes `function` with no arguments, so bind the config
        # here. Give each trial its own deep copy so one trial cannot alter the
        # config seen by the next.
        # NOTE(review): run_single reads hyperparameters from this local config,
        # not from wandb.config, so values chosen by the sweep controller may not
        # reach the model. Confirm this is intended.
        wandb.agent(
            sweep_id,
            function=lambda: run_single(copy.deepcopy(config)),
            count=config['reps_per_agent'],
        )
    else:
        run_single(config)


def run_single(config):
    """Train one PPO model as described by ``config``, logging to Weights & Biases.

    Config keys read here:

    * ``project`` and ``total_timesteps`` (required, top level).
    * ``env``: passed to ``make_env_func`` to build the (single) environment.
    * ``algo``: ``name`` (must be ``'PPO'``) and ``policy_name`` (required).
      All other keys are forwarded to ``PPO`` as keyword arguments.
    * ``video`` (optional): ``enable``, ``frequency`` and ``length`` for
      ``VecVideoRecorder``.
    * ``pca`` (optional): ``enable``. All other keys are forwarded to ``PCA``.

    ``config`` itself is not modified.

    Raises:
        ValueError: If ``algo.name`` is not ``'PPO'``.
        RuntimeError: If ``pca.enable`` is set but ``PCA`` is still ``None``.
    """
    # Work on a private copy: sections are consumed with pop() below, and the
    # caller (e.g. a multi-trial sweep) may reuse the same dict.
    config = copy.deepcopy(config)
    video_conf = config.get('video', {})
    env_conf = dict(config.get('env', {}))
    algo_kwargs = dict(config.get('algo', {}))
    pca_kwargs = dict(config.get('pca', {}))

    # Validate before creating a W&B run or an environment.
    algo_name = algo_kwargs.pop('name', None)
    if algo_name != 'PPO':
        raise ValueError(f"Only algo.name == 'PPO' is supported, got {algo_name!r}")
    policy_name = algo_kwargs.pop('policy_name')
    use_pca = pca_kwargs.pop('enable', False)
    if use_pca and PCA is None:
        raise RuntimeError("pca.enable is set but PCA is None (no implementation assigned)")

    with wandb.init(
        project=config['project'],
        config=config,
        sync_tensorboard=True,
        monitor_gym=True,
        save_code=True,
    ) as wandb_run:
        env = DummyVecEnv([make_env_func(env_conf)])
        if video_conf.get('enable', False):
            env = VecVideoRecorder(
                env,
                f"videos/{wandb_run.id}",
                record_video_trigger=lambda step: step % video_conf['frequency'] == 0,
                video_length=video_conf['length'],
            )
        try:
            model = PPO(policy_name, env, **algo_kwargs)

            if use_pca:
                model.policy.action_dist = PCA(model.policy.action_space.shape, **pca_kwargs)

            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=WandbCallback(),
            )
        finally:
            # Release the environment; closing also finalizes any recording
            # VecVideoRecorder still has open.
            env.close()


def make_env_func(env_conf):
    """Return a zero-argument factory that builds a ``Monitor``-wrapped env.

    Args:
        env_conf: Mapping with the environment id under ``name`` and an
            optional ``legacy_fancy`` flag. All other keys are passed as
            keyword arguments to the env constructor. The mapping is not
            modified, so the factory can be called more than once.

    Returns:
        A callable taking no arguments that creates the environment with
        ``fancy_gym.make`` when ``legacy_fancy`` is truthy, else ``gym.make``.
    """
    def func():
        # Copy so repeated calls (and the caller's dict) see the original keys.
        kwargs = dict(env_conf)
        env_name = kwargs.pop('name')
        legacy_fancy = kwargs.pop('legacy_fancy', False)
        if legacy_fancy:  # TODO: Remove when no longer needed
            env = fancy_gym.make(env_name, **kwargs)
        else:
            env = gym.make(env_name, **kwargs)
        return Monitor(env)
    return func


def main(argv=None):
    """Command-line entry point: ``<script> CONFIG_FILE [CONFIG_NAME]``.

    ``CONFIG_NAME`` defaults to ``'DEFAULT'``, the same name ``load_config``
    uses for imports without an explicit ``:name``.

    Args:
        argv: Argument list to parse. Defaults to ``sys.argv[1:]`` (argparse's
            behavior for ``None``).
    """
    parser = argparse.ArgumentParser(
        description='Train a PPO agent from a YAML config and log it to Weights & Biases.'
    )
    parser.add_argument('filename', help='Path to the YAML config file.')
    parser.add_argument(
        'name',
        nargs='?',
        default='DEFAULT',
        help='Name of the config document to run (default: %(default)s).',
    )
    args = parser.parse_args(argv)
    run(args.filename, args.name)


if __name__ == '__main__':
    main()