fancy_gym/alr_envs/mp/basis_generator_factory.py

18 lines
885 B
Python
Raw Normal View History

2022-05-03 19:51:54 +02:00
from mp_pytorch import PhaseGenerator, NormalizedRBFBasisGenerator, ZeroStartNormalizedRBFBasisGenerator
from mp_pytorch.basis_gn.rhytmic_basis import RhythmicBasisGenerator
# Type names accepted by get_basis_generator (compared after lower-casing);
# the list is echoed back in the error raised for an unsupported name.
ALL_TYPES = [
    "rbf",
    "zero_rbf",
    "rhythmic",
]
def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerator, **kwargs):
    """Build a basis generator by name.

    Args:
        basis_generator_type: One of ``"rbf"``, ``"zero_rbf"`` or ``"rhythmic"``;
            the comparison is case-insensitive (the name is lower-cased first).
        phase_generator: Passed as the first positional argument to the
            selected generator class.
        **kwargs: Forwarded unchanged to the selected generator's constructor.

    Returns:
        An instance of ``NormalizedRBFBasisGenerator`` ("rbf"),
        ``ZeroStartNormalizedRBFBasisGenerator`` ("zero_rbf") or
        ``RhythmicBasisGenerator`` ("rhythmic").

    Raises:
        ValueError: If the lower-cased name is not one of the supported types.
    """
    requested = basis_generator_type.lower()

    # Name -> class dispatch table; keys mirror the entries of ALL_TYPES.
    generator_by_name = {
        "rbf": NormalizedRBFBasisGenerator,
        "zero_rbf": ZeroStartNormalizedRBFBasisGenerator,
        "rhythmic": RhythmicBasisGenerator,
    }

    generator_cls = generator_by_name.get(requested)
    if generator_cls is None:
        raise ValueError(f"Specified basis generator type {requested} not supported, "
                         f"please choose one of {ALL_TYPES}.")

    return generator_cls(phase_generator, **kwargs)