2022-05-03 19:51:54 +02:00
|
|
|
from mp_pytorch.mp.dmp import DMP
|
|
|
|
from mp_pytorch.mp.promp import ProMP
|
|
|
|
from mp_pytorch.mp.idmp import IDMP
|
|
|
|
|
|
|
|
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
|
|
|
|
|
|
|
# Registry mapping each supported (lower-case) type name to its movement
# primitive class. It is the single source of truth for both dispatch and
# ALL_TYPES, so adding a new primitive only requires one entry here.
# NOTE(review): the classes are bound at import time; code that monkeypatches
# the module-level names (e.g. mock.patch on ProMP) should patch this dict.
_TRAJECTORY_GENERATORS = {
    "promp": ProMP,
    "dmp": DMP,
    "idmp": IDMP,
}

# Public list of accepted type names, in registry order: ["promp", "dmp", "idmp"].
ALL_TYPES = list(_TRAJECTORY_GENERATORS)


def get_trajectory_generator(
    trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
):
    """Instantiate the movement primitive selected by ``trajectory_generator_type``.

    Args:
        trajectory_generator_type: Name of the movement primitive type. It is
            lower-cased before lookup, so matching is case-insensitive against
            ``ALL_TYPES``.
        action_dim: Passed as the second positional constructor argument.
        basis_generator: Passed as the first positional constructor argument.
        **kwargs: Forwarded unchanged to the selected constructor.

    Returns:
        A new ``ProMP``, ``DMP`` or ``IDMP`` instance, depending on the type name.

    Raises:
        ValueError: If the (lower-cased) type name is not one of ``ALL_TYPES``.
    """
    trajectory_generator_type = trajectory_generator_type.lower()

    mp_class = _TRAJECTORY_GENERATORS.get(trajectory_generator_type)
    if mp_class is None:
        raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")

    # Constructor argument order (basis_generator, action_dim) matches all three classes.
    return mp_class(basis_generator, action_dim, **kwargs)
|