fancy_gym/alr_envs/mp/mp_factory.py
2022-05-03 19:51:54 +02:00

22 lines
900 B
Python

from mp_pytorch.mp.dmp import DMP
from mp_pytorch.mp.promp import ProMP
from mp_pytorch.mp.idmp import IDMP
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
# Registry mapping each supported (lower-case) movement primitive name to its
# class. This is the single source of truth: both the dispatch in
# get_movement_primitive and the public ALL_TYPES list are derived from it, so
# adding a new primitive only requires one entry here.
_MP_CLASSES = {
    "promp": ProMP,
    "dmp": DMP,
    "idmp": IDMP,
}

# Public list of accepted type names, in registry (insertion) order.
ALL_TYPES = list(_MP_CLASSES)


def get_movement_primitive(
    movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
):
    """Instantiate a movement primitive of the requested type.

    Args:
        movement_primitives_type: Name of the primitive. It is lower-cased
            before lookup, so matching is case-insensitive against ``ALL_TYPES``.
        action_dim: Forwarded as the second positional argument of the
            selected primitive class's constructor.
        basis_generator: Forwarded as the first positional argument of the
            selected primitive class's constructor.
        **kwargs: Forwarded unchanged as keyword arguments to the constructor.

    Returns:
        An instance of ``ProMP``, ``DMP`` or ``IDMP``, depending on
        ``movement_primitives_type``.

    Raises:
        ValueError: If the (lower-cased) type is not one of ``ALL_TYPES``.
    """
    movement_primitives_type = movement_primitives_type.lower()
    mp_class = _MP_CLASSES.get(movement_primitives_type)
    if mp_class is None:
        raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")
    return mp_class(basis_generator, action_dim, **kwargs)