import wandb
import yaml
import os
import math
import random
import copy
import re
import itertools
from collections.abc import *
from functools import partial
from multiprocessing import Process
from threading import Thread
import git
import datetime
from pprint import pprint

import pdb
d = pdb.set_trace  # debugging shorthand: call d() anywhere to break into pdb

# When True, config keys left unconsumed after a run raise instead of only warning.
REQUIRE_CONFIG_CONSUMED = False
# Multiprocessing start method; appears unused in this file — TODO confirm callers.
DEFAULT_START_METHOD = 'fork'
# Default for wandb.init(reinit=...) unless the run config overrides it.
DEFAULT_REINIT = True

# Swap to Thread for in-process parallelism instead of forked processes.
Parallelization_Primitive = Process  # Thread

# pyslurm is optional: without it only local execution is available.
try:
    import pyslurm
except ImportError:
    slurm_avaible = False
    print('[!] Slurm not avaible.')
else:
    slurm_avaible = True

# TODO: Implement Testing
# TODO: Implement Ablative


class Slate():
    """Experiment orchestrator.

    Loads layered YAML experiment configs (with imports and dot-notation
    overrides), expands them into per-run configurations (grid / ablative
    versions), and executes the configured runner either locally (in
    threads/processes) or as a slurm array job, with optional wandb sweeps.
    """

    def __init__(self, runners):
        # Built-in debug runners; caller-supplied runners can override these keys.
        self.runners = {
            'void': Void_Runner,
            'printConfig': Print_Config_Runner,
            'pdb': PDB_Runner,
        }
        self.runners.update(runners)
        self._version = False  # lazily cached git sha (see get_version)
        self.sweep_id = None

    def load_config(self, filename, name):
        """Load experiment <name> from <filename> and return the merged config."""
        config, stack = self._load_config(filename, name)
        print('[i] Merged Configs: ', stack)
        # Pristine copy used later to resolve {config...} template variables.
        self._config = copy.deepcopy(config)
        self.consume(config, 'vars', {})
        return config

    def _load_config(self, filename, name, stack=None):
        """Recursively load a YAML experiment document and everything it imports.

        Returns (config, stack); stack records the merge order for logging.
        Raises if the named experiment is not present in the file.
        """
        # BUGFIX: the default used to be a mutable `[]`, which accumulated
        # entries across separate load_config calls in the same process.
        if stack is None:
            stack = []
        stack.append(f'(unknown):{name}')
        with open(filename, 'r') as f:
            docs = yaml.safe_load_all(f)
            for doc in docs:
                if 'name' in doc:
                    if doc['name'] == name:
                        if 'import' in doc:
                            imports = doc['import'].split(',')
                            del doc['import']
                            for imp in imports:
                                if imp[0] == ' ':
                                    imp = imp[1:]
                                if imp == "$":
                                    # "$" is shorthand for "this file, DEFAULT experiment".
                                    imp = ':DEFAULT'
                                rel_path, *opt = imp.split(':')
                                if len(opt) == 0:
                                    nested_name = 'DEFAULT'
                                elif len(opt) == 1:
                                    nested_name = opt[0]
                                else:
                                    raise Exception('Malformed import statement. Must be <import file:exp>, <import :exp>, <import file> for file:DEFAULT or <import $> for :DEFAULT.')
                                # Empty path means "same file"; otherwise resolve relative to this file.
                                nested_path = os.path.normpath(os.path.join(os.path.dirname(filename), rel_path)) if len(rel_path) else filename
                                child, stack = self._load_config(nested_path, nested_name, stack=stack)
                                # The importing document wins over the imported one.
                                doc = self.deep_update(child, doc)
                        return doc, stack
            raise Exception(f'Unable to find experiment <{name}> in <(unknown)>')

    def deep_update(self, d, u, traverse_dot_notation=True):
        """Merge mapping u into d in place and return d.

        Keys like 'a.b' traverse nested dicts, except below a 'parameters'
        node (wandb sweep syntax uses literal dotted keys there).
        """
        for kstr, v in u.items():
            if traverse_dot_notation:
                ks = kstr.split('.')
            else:
                ks = [kstr]
            head = d
            for k in ks:
                if k in ['parameters']:
                    traverse_dot_notation = False
                last_head = head
                if k not in head:
                    head[k] = {}
                head = head[k]
            if isinstance(v, Mapping):
                # BUGFIX: was `self.deep_update(d.get(k, {}), ...)`, which
                # looked the *last* path component up at the root dict and
                # thereby dropped pre-existing nested values for dotted keys;
                # merge into the node we actually walked to instead.
                last_head[ks[-1]] = self.deep_update(head, v, traverse_dot_notation=traverse_dot_notation)
            else:
                last_head[ks[-1]] = v
        return d

    def expand_vars(self, string, delta_desc='BASE', **kwargs):
        """Expand {placeholders} in a string; non-strings pass through unchanged.

        The exact string '{rand}' yields a fresh random int (not a str).
        """
        if isinstance(string, str):
            rand = int(random.random()*99999999)
            if string == '{rand}':
                return rand
            return string.format(delta_desc=delta_desc, **kwargs, rand=rand)
        return string

    def apply_nested(self, d, f):
        """Apply f to every leaf value of a nested dict/list structure in place."""
        for k, v in d.items():
            if isinstance(v, dict):
                self.apply_nested(v, f)
            elif isinstance(v, list):
                for i, e in enumerate(v):
                    # Wrap each element in a one-key dict so nested containers recurse too.
                    ptr = {'PTR': d[k][i]}
                    self.apply_nested(ptr, f)
                    d[k][i] = ptr['PTR']
            else:
                d[k] = f(v)

    def deep_expand_vars(self, dict, **kwargs):
        """Expand template variables in every leaf of a nested config.

        NOTE: the parameter name shadows the builtin `dict`; kept for
        signature backward compatibility.
        """
        self.apply_nested(dict, lambda x: self.expand_vars(x, **kwargs))

    def consume(self, conf, key, default=None, expand=False, **kwargs):
        """Pop `key` (dot-notation) from conf and return its value.

        default: returned when the key is missing; None marks the key as
        required (a missing required key raises KeyError).
        expand: if True, recursively expand {var} templates in the value.
        An empty key operates on conf as a whole (expansion only, no pop).
        """
        if key == '':
            if expand:
                self.deep_expand_vars(conf, config=self._config, **kwargs)
            elif isinstance(conf, str):
                # Re-expand until no placeholders remain (vars may expand to vars).
                while conf.find('{') != -1:
                    conf = self.expand_vars(conf, config=self._config, **kwargs)
            return conf
        keys_arr = key.split('.')
        if len(keys_arr) == 1:
            k = keys_arr[0]
            if default is not None:
                if isinstance(conf, Mapping):
                    val = conf.get(k, default)
                else:
                    # Walked into a non-mapping: only the default can be returned.
                    # (Removed an unreachable `raise Exception('')` here.)
                    return default
            else:
                val = conf[k]  # required key: KeyError if missing
            if k in conf:
                del conf[k]

            if expand:
                self.deep_expand_vars(val, config=self._config, **kwargs)
            elif isinstance(val, str):
                while val.find('{') != -1:
                    val = self.expand_vars(val, config=self._config, **kwargs)

            return val
        # Descend one level and recurse on the remaining key path.
        child = conf.get(keys_arr[0], {})
        child_keys = '.'.join(keys_arr[1:])
        return self.consume(child, child_keys, default=default, expand=expand, **kwargs)

    def get_version(self):
        """Return (and cache) the git commit sha of the surrounding repo."""
        if not self._version:
            repo = git.Repo(search_parent_directories=True)
            sha = repo.head.object.hexsha
            self._version = sha
        return self._version

    def _calc_num_jobs(self, schedC, num_conv_versions):
        """Number of slurm jobs needed to cover all repetitions."""
        schedulerC = copy.deepcopy(schedC)
        # Total reps: explicit 'repetitions' wins over reps_per_version * versions.
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
        agents_per_job = self.consume(schedulerC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedulerC, 'reps_per_agent', 1)
        reps_per_job = reps_per_agent * agents_per_job
        jobs_needed = math.ceil(reps / reps_per_job)
        return jobs_needed

    def _reps_for_job(self, schedC, job_id, num_conv_versions):
        """Repetition indices assigned to `job_id`; None (local run) means all."""
        schedulerC = copy.deepcopy(schedC)
        num_jobs = self._calc_num_jobs(schedulerC, num_conv_versions)
        reps = self.consume(schedulerC, 'repetitions', self.consume(schedulerC, 'reps_per_version', 1)*num_conv_versions)
        if job_id is None:
            return list(range(0, reps))
        # Round-robin so each job gets a near-equal share.
        reps_for_job = [[] for i in range(num_jobs)]
        for i in range(reps):
            reps_for_job[i % num_jobs].append(i)
        return reps_for_job[job_id]

    def run_local(self, filename, name, job_id, sweep_id):
        """Run (this job's slice of) an experiment on the local machine."""
        config = self.load_config(filename, name)
        num_conv_versions = self._get_num_conv_versions(config)
        schedulerC = copy.deepcopy(config.get('scheduler', {}))
        rep_ids = self._reps_for_job(schedulerC, job_id, num_conv_versions)
        self.sweep_id = sweep_id
        self._init_sweep(config)
        self._fork_processes(config, rep_ids)

    def run_slurm(self, filename, name):
        """Submit the experiment as a slurm array job (one task per job index)."""
        assert slurm_avaible, 'pyslurm does not seem to be installed on this system.'
        config = self.load_config(filename, name)
        slurmC = self.consume(config, 'slurm', expand=True)
        schedC = self.consume(config, 'scheduler')
        s_name = self.consume(slurmC, 'name')

        num_conv_versions = self._get_num_conv_versions(config)

        # Pre Validation: build the runner once up front so config errors
        # surface before we pay for a slurm submission.
        runnerName, _ = self.consume(config, 'runner'), self.consume(copy.deepcopy(config), 'wandb', {}, expand=True)
        if self.consume(slurmC, 'pre_validate', True):
            Runner = self.runners[runnerName]
            runner = Runner(self, config)
            runner.setup()

        self._init_sweep(config)
        self.consume(config, 'wandb')

        python_script = 'main.py'
        sh_lines = ['#!/bin/bash']
        sh_lines += self.consume(slurmC, 'sh_lines', [])
        if venv := self.consume(slurmC, 'venv', False):
            sh_lines += [f'source activate {venv}']
        # Each array task re-invokes this script with its task id as job id.
        sh_lines += [f'python3 {python_script} (unknown) {name} -j $SLURM_ARRAY_TASK_ID --sweep_id {self.sweep_id}']
        script = "\n".join(sh_lines)

        num_jobs = self._calc_num_jobs(schedC, num_conv_versions)

        last_job_idx = num_jobs - 1
        num_parallel_jobs = min(self.consume(slurmC, 'num_parallel_jobs', num_jobs), num_jobs)
        array = f'0-{last_job_idx}%{num_parallel_jobs}'
        # All remaining slurmC keys are forwarded verbatim to pyslurm.
        job = pyslurm.JobSubmitDescription(name=s_name, script=script, array=array, **slurmC)
        job_id = job.submit()
        print(f'[>] Job submitted to slurm with id {job_id}')
        with open('job_hist.log', 'a') as f:
            f.write(f'(unknown):{name} submitted to slurm with ids {job_id}_0-{num_jobs} on [git:{self.get_version()}] at {datetime.datetime.now()}\n')

    def _fork_processes(self, config, rep_ids):
        """Distribute rep_ids over agents (threads/processes) on this node."""
        schedC = self.consume(config, 'scheduler', {})
        agents_per_job = self.consume(schedC, 'agents_per_job', 1)
        reps_per_agent = self.consume(schedC, 'reps_per_agent', 1)

        node_reps = len(rep_ids)
        num_p = min(agents_per_job, math.ceil(node_reps / reps_per_agent))

        if num_p == 1:
            # No parallelism needed; skip fork overhead entirely.
            print('[i] Running within main thread')
            self._run_process(config, rep_ids=rep_ids, p_ind=0)
            return

        procs = []

        reps_done = 0

        for p in range(num_p):
            print(f'[i] Spawning seperate thread/process ({p+1}/{num_p})')
            num_reps = min(node_reps - reps_done, reps_per_agent)
            proc_rep_ids = [rep_ids[i] for i in list(range(reps_done, reps_done+num_reps))]
            proc = Parallelization_Primitive(target=partial(self._run_process, config, rep_ids=proc_rep_ids, p_ind=p))
            proc.start()
            procs.append(proc)
            reps_done += num_reps

        for proc in procs:
            proc.join()
        print(f'[i] All threads/processes have terminated')

    def _init_sweep(self, config):
        """Create a wandb sweep once (if enabled) and remember its id."""
        if self.sweep_id is None and self.consume(config, 'sweep.enable', False):
            sweepC = self.consume(config, 'sweep')
            wandbC = copy.deepcopy(config['wandb'])
            project = self.consume(wandbC, 'project')

            self.sweep_id = wandb.sweep(
                sweep=sweepC,
                project=project
            )

    def _run_process(self, orig_config, rep_ids, p_ind):
        """Entry point of one agent: run reps directly or via the wandb sweep."""
        config = copy.deepcopy(orig_config)
        if self.consume(config, 'sweep.enable', False):
            # (Removed a dead `wandbC = copy.deepcopy(config['wandb'])` store here.)
            wandb.agent(self.sweep_id, function=partial(self._run_from_sweep, config, p_ind=p_ind), count=len(rep_ids))
        else:
            self.consume(config, 'sweep', {})
            self._run_single(config, rep_ids, p_ind=p_ind)

    def _run_single(self, orig_config, rep_ids, p_ind):
        """Run the given repetitions sequentially, one wandb run each."""
        print(f'[P{p_ind}] I will work on reps {rep_ids}')
        runnerName = self.consume(orig_config, 'runner')
        project = self.consume(orig_config, 'wandb.project', orig_config.get('project', orig_config.get('name')))

        Runner = self.runners[runnerName]

        for r in rep_ids:
            config = copy.deepcopy(orig_config)
            runnerConf = self._make_config_for_run(config, r)
            wandbC = self.consume(runnerConf, 'wandb', {}, expand=True, delta_desc=runnerConf.pop('delta_desc', 'BASE'))
            with wandb.init(
                project=project,
                config=copy.deepcopy(runnerConf),
                reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
                settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
                **wandbC
            ) as run:
                runner = Runner(self, runnerConf)
                runner.setup()
                runner.run(run)

            # Runners are expected to consume() every key they use.
            if runnerConf != {}:
                msg = ('Config was not completely consumed: ', runnerConf)
                if REQUIRE_CONFIG_CONSUMED:
                    raise Exception(msg)
                else:
                    print(msg)

    def _run_from_sweep(self, orig_config, p_ind):
        """Callback for wandb.agent: run once with sweep-chosen parameters."""
        runnerName, wandbC = self.consume(orig_config, 'runner'), self.consume(orig_config, 'wandb', {}, expand=True)
        project = self.consume(wandbC, 'project')

        Runner = self.runners[runnerName]

        with wandb.init(
            project=project,
            reinit=self.consume(wandbC, 'reinit', DEFAULT_REINIT),
            settings=wandb.Settings(**self.consume(wandbC, 'settings', {})),
            **wandbC
        ) as run:
            config = copy.deepcopy(orig_config)
            # Sweep-selected hyperparameters override the static config.
            self.deep_update(config, wandb.config)
            run.config = copy.deepcopy(config)
            runner = Runner(self, config)
            runner.setup()
            runner.run(run)

            if config != {}:
                msg = ('Config was not completely consumed: ', config)
                if REQUIRE_CONFIG_CONSUMED:
                    raise Exception(msg)
                else:
                    print(msg)

    def _make_configs_for_runs(self, config):
        """Expand grid and ablative sections into the list of config versions."""
        c = copy.deepcopy(config)

        grid_versions = self._make_grid_versions(c)
        all_versions = self._make_ablative_versions(c, grid_versions)

        return all_versions

    def _get_num_conv_versions(self, config):
        """Number of distinct config versions this experiment expands into."""
        return len(self._make_configs_for_runs(config))

    def _make_config_for_run(self, config, r):
        """Pick the config version for repetition r (cycled round-robin)."""
        all_versions = self._make_configs_for_runs(config)

        i = r % len(all_versions)
        print(f'[d] Running version {i}/{len(all_versions)} in run {r}')
        cur_conf = all_versions[i]
        if 'ablative' in cur_conf:
            del cur_conf['ablative']
        return cur_conf

    def _make_grid_versions(self, config):
        """Expand a 'grid' section into the cartesian product of its values."""
        if 'grid' in config:
            return params_combine(config, 'grid', itertools.product)
        return [config]

    def _make_ablative_versions(self, config, grid_versions):
        """Append one-off ablation variants to the grid versions, if configured."""
        if 'ablative' in config:
            return grid_versions + ablative_expand(grid_versions)
        else:
            return grid_versions

    def from_args(self):
        """CLI entry point: parse argv and dispatch to run_local / run_slurm."""
        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("config_file", nargs='?', default=None)
        parser.add_argument("experiment", nargs='?', default='DEFAULT')
        parser.add_argument("-s", "--slurm", action="store_true")
        parser.add_argument("-w", "--worker", action="store_true")
        parser.add_argument("-j", "--job_id", default=None, type=int)
        parser.add_argument("--sweep_id", default=None, type=str)

        args = parser.parse_args()

        print(f'[i] I have job_id {args.job_id}')
        print(f'[i] Running on version [git:{self.get_version()}]')

        if args.worker:
            raise Exception('Not yet implemented')

        assert args.config_file is not None, 'Need to supply config file.'
        if args.slurm:
            self.run_slurm(args.config_file, args.experiment)
        else:
            self.run_local(args.config_file, args.experiment, args.job_id, args.sweep_id)


def params_combine(config: dict, key: str, iter_func):
    """Expand config[key] (a nested dict of value-lists) into one config per
    value combination produced by iter_func (e.g. itertools.product for a grid).

    Each returned config has the expanded values written back at their nested
    locations, the `key` section removed, and a 'delta_desc' name suffix added.
    Returns [config] unchanged if iter_func is None.
    """
    if iter_func is None:
        return [config]

    combined_configs = []
    # Flatten the grid/list dictionary: {('a', 'b'): [v1, v2], ...} with the
    # key being the path tuple and the value the list of candidate values.
    tuple_dict = flatten_dict_to_tuple_keys(config[key])
    _param_names = ['.'.join(t) for t in tuple_dict]
    # (Removed an unused dead store: `param_lengths = map(len, tuple_dict.values())`.)

    # Create a new config for each parameter setting.
    for values in iter_func(*tuple_dict.values()):
        _config = copy.deepcopy(config)

        # Remove the grid/list section itself from the expanded config.
        del _config[key]

        # Write each chosen value back at its nested location.
        for i, t in enumerate(tuple_dict.keys()):
            insert_deep_dictionary(d=_config, t=t, value=values[i])

        _config = extend_config_name(_config, _param_names, values)
        combined_configs.append(_config)
    return combined_configs


def ablative_expand(conf_list):
    """Create one-off ablation variants for every config in conf_list.

    For each entry in a config's 'ablative' section (a nested dict of value
    lists), one variant per candidate value is produced, with a matching
    'delta_desc' suffix. Returns the flat list of all variants.
    """
    expanded = []
    for base in conf_list:
        flat = flatten_dict_to_tuple_keys(base['ablative'])
        names = ['.'.join(path) for path in flat]
        for idx, path in enumerate(flat):
            for candidate in flat[path]:
                variant = copy.deepcopy(base)
                insert_deep_dictionary(variant, path, candidate)
                variant = extend_config_name(variant, [names[idx]], [candidate])
                expanded.append(variant)
    return expanded


def flatten_dict_to_tuple_keys(d: MutableMapping):
    """Flatten a nested mapping into {key-path-tuple: list} entries.

    Only list (MutableSequence) leaves are kept; scalar leaves are dropped.
    Example: {'a': {'b': [1]}} -> {('a', 'b'): [1]}.
    """
    result = {}
    for key, value in d.items():
        if isinstance(value, MutableMapping):
            for sub_path, leaf in flatten_dict_to_tuple_keys(value).items():
                result[(key, *sub_path)] = leaf
        elif isinstance(value, MutableSequence):
            result[(key,)] = value
    return result


def insert_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Assign `value` at the nested key-path `t` inside `d`.

    Intermediate dicts are created on demand. A non-tuple `t` is treated
    as a plain (top-level) key.
    """
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for key in t[:-1]:
        node = node.setdefault(key, {})
    node[t[-1]] = value


def append_deep_dictionary(d: MutableMapping, t: tuple, value):
    """Append `value` to a list at the nested key-path `t` inside `d`.

    Intermediate dicts and the leaf list are created on demand. A non-tuple
    `t` is treated as a plain key and the value is assigned (not appended).
    """
    if type(t) is not tuple:
        d[t] = value
        return
    node = d
    for key in t[:-1]:
        node = node.setdefault(key, {})
    node.setdefault(t[-1], []).append(value)


def extend_config_name(config: dict, param_names: list, values: list) -> dict:
    """Append an encoded param/value summary to config['delta_desc'].

    Creates the key if absent; otherwise joins with '_'. Returns config.
    """
    suffix = convert_param_names(param_names, values)
    if 'delta_desc' in config:
        config['delta_desc'] = config['delta_desc'] + '_' + suffix
    else:
        config['delta_desc'] = suffix
    return config


def convert_param_names(_param_names: list, values: list) -> str:
    """Build a compact, filesystem/wandb-friendly run-name fragment.

    Pairs each (shortened) parameter name with its value, then strips or
    replaces characters that are awkward in names (quotes, spaces, brackets,
    commas).
    """
    _converted_name = '_'.join("{}{}".format(
        shorten_param(k), v) for k, v in zip(_param_names, values))
    # Raw strings below fix invalid escape sequences ("\[", "\]") that raise
    # SyntaxWarning on Python 3.12+; the patterns themselves are unchanged.
    _converted_name = re.sub(r"[' ]", '', _converted_name)
    _converted_name = re.sub(r'["]', '', _converted_name)
    _converted_name = re.sub(r"[(\[]", '_', _converted_name)
    _converted_name = re.sub(r"[)\]]", '', _converted_name)
    _converted_name = re.sub(r"[,]", '_', _converted_name)
    return _converted_name


def shorten_param(_param_name):
    """Abbreviate a dotted parameter name.

    Each non-leaf path part is cut to 3 chars; the leaf is reduced to the
    initials of its underscore-separated words.
    Example: 'scheduler.reps_per_agent' -> 'sch.rpa'.
    """
    *path, leaf = _param_name.split('.')
    prefix = '.'.join(part[:3] for part in path)
    abbrev = ''.join(word[0] for word in leaf.split('_'))
    return f'{prefix}.{abbrev}' if prefix else abbrev


class Slate_Runner():
    """Base class for experiment runners executed by Slate.

    Subclasses override setup() and run(); both are no-ops here. Runners are
    expected to consume() every config key they use.
    """

    def __init__(self, slate, config):
        """Keep references to the orchestrating Slate and this run's config."""
        self.slate = slate
        self.config = config

    def setup(self):
        """Hook called once before run(); default: do nothing."""
        return None

    def run(self, run):
        """Hook called with the active wandb run; default: do nothing."""
        return None


class Print_Config_Runner(Slate_Runner):
    """Debug runner: pretty-prints the config before and after variable
    expansion, then consumes every key so no leftover warning fires."""

    def run(self, run):
        """Print raw and fully-expanded config, then empty it."""
        slate, config = self.slate, self.config

        pprint(config)
        print('---')
        pprint(slate.consume(config, '', expand=True))
        for key in list(config):
            del config[key]


class Void_Runner(Slate_Runner):
    """No-op runner: does nothing except consume its entire config."""

    def run(self, run):
        """Delete every config key so the consumed-config check passes."""
        slate, config = self.slate, self.config
        for key in list(config):
            del config[key]


class PDB_Runner(Slate_Runner):
    """Debug runner that drops straight into the pdb debugger."""

    def run(self, run):
        """Enter interactive debugging (module-level d = pdb.set_trace)."""
        d()


if __name__ == '__main__':
    # This module is a library: construct a Slate(runners) in your own main
    # script and call .from_args() there instead of executing this file.
    raise Exception('You are using it wrong...')