fancy_gym/alr_envs/mp/mp_factory.py
2022-06-29 09:37:18 +02:00

22 lines
909 B
Python

from mp_pytorch.mp.dmp import DMP
from mp_pytorch.mp.promp import ProMP
from mp_pytorch.mp.idmp import IDMP
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
# Movement primitive type names recognised by get_trajectory_generator;
# also listed in its error message when an unknown name is requested.
ALL_TYPES = [
    "promp",
    "dmp",
    "idmp",
]
def get_trajectory_generator(
        trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
):
    """Instantiate the movement primitive selected by name.

    Args:
        trajectory_generator_type: Name of the primitive to build. It is
            lower-cased before matching, so the lookup is case-insensitive.
            Recognised names are "promp", "dmp" and "idmp".
        action_dim: Passed as the second positional argument to the selected
            primitive's constructor.
        basis_generator: Passed as the first positional argument to the
            selected primitive's constructor.
        **kwargs: Forwarded unchanged to the selected constructor.

    Returns:
        A new ``ProMP``, ``DMP`` or ``IDMP`` instance, depending on the name.

    Raises:
        ValueError: If the lower-cased name is not one of the recognised names.
    """
    mp_name = trajectory_generator_type.lower()

    # Name -> constructor lookup; every entry is called with the same arguments.
    constructors = {
        "promp": ProMP,
        "dmp": DMP,
        "idmp": IDMP,
    }
    mp_class = constructors.get(mp_name)

    if mp_class is None:
        # The message reports the lower-cased name, not the caller's spelling.
        raise ValueError(f"Specified movement primitive type {mp_name} not supported, "
                         f"please choose one of {ALL_TYPES}.")

    return mp_class(basis_generator, action_dim, **kwargs)