from mp_pytorch import LinearPhaseGenerator, ExpDecayPhaseGenerator
from mp_pytorch.phase_gn.rhythmic_phase_generator import RhythmicPhaseGenerator
from mp_pytorch.phase_gn.smooth_phase_generator import SmoothPhaseGenerator

ALL_TYPES = ["linear", "exp", "rhythmic", "smooth"]


def get_phase_generator(phase_generator_type, **kwargs):
    phase_generator_type = phase_generator_type.lower()
    if phase_generator_type == "linear":
        return LinearPhaseGenerator(**kwargs)
    elif phase_generator_type == "exp":
        return ExpDecayPhaseGenerator(**kwargs)
    elif phase_generator_type == "rhythmic":
        return RhythmicPhaseGenerator(**kwargs)
    elif phase_generator_type == "smooth":
        return SmoothPhaseGenerator(**kwargs)
    else:
        raise ValueError(f"Specified phase generator type {phase_generator_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")