fancy_gym/alr_envs/mp/mp_factory.py

22 lines
909 B
Python
Raw Normal View History

2022-05-03 19:51:54 +02:00
from mp_pytorch.mp.dmp import DMP
from mp_pytorch.mp.promp import ProMP
from mp_pytorch.mp.idmp import IDMP
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
# Trajectory-generator names accepted (case-insensitively) by ``get_trajectory_generator``.
ALL_TYPES = ["promp", "dmp", "idmp"]


def get_trajectory_generator(
        trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
):
    """Build the movement-primitive trajectory generator named by ``trajectory_generator_type``.

    Args:
        trajectory_generator_type: Name of the generator; lower-cased before matching
            against ``"promp"``, ``"dmp"`` or ``"idmp"``.
        action_dim: Passed as the second positional constructor argument.
        basis_generator: Passed as the first positional constructor argument.
        **kwargs: Forwarded unchanged to the selected generator's constructor.

    Returns:
        A ``ProMP``, ``DMP`` or ``IDMP`` instance, depending on the requested type.

    Raises:
        ValueError: If the lower-cased name matches none of the supported types.
    """
    trajectory_generator_type = trajectory_generator_type.lower()

    # Classes are looked up at call time, so rebinding the module-level names
    # (e.g. in tests) is still honoured.
    constructors = {
        "promp": ProMP,
        "dmp": DMP,
        "idmp": IDMP,
    }
    generator_cls = constructors.get(trajectory_generator_type)
    if generator_cls is None:
        raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")
    return generator_cls(basis_generator, action_dim, **kwargs)