22 lines
900 B
Python
22 lines
900 B
Python
|
from mp_pytorch.mp.dmp import DMP
|
||
|
from mp_pytorch.mp.promp import ProMP
|
||
|
from mp_pytorch.mp.idmp import IDMP
|
||
|
|
||
|
from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
||
|
|
||
|
# Lower-case identifiers of the movement primitive types that
# get_movement_primitive can construct; also listed in its error message.
ALL_TYPES = ["promp", "dmp", "idmp"]
|
||
|
|
||
|
|
||
|
def get_movement_primitive(
        movement_primitives_type: str,
        action_dim: int,
        basis_generator: BasisGenerator,
        **kwargs
):
    """Build the movement primitive selected by ``movement_primitives_type``.

    The type name is matched case-insensitively (it is lower-cased before
    the lookup).

    Args:
        movement_primitives_type: Name of the primitive to create; one of
            "promp", "dmp" or "idmp" in any letter case.
        action_dim: Forwarded to the primitive's constructor as its second
            positional argument.
        basis_generator: Forwarded to the primitive's constructor as its
            first positional argument.
        **kwargs: Extra keyword arguments passed through unchanged to the
            selected constructor.

    Returns:
        A new ``ProMP``, ``DMP`` or ``IDMP`` instance.

    Raises:
        ValueError: If the lower-cased type name is not a supported type.
    """
    movement_primitives_type = movement_primitives_type.lower()

    # Dispatch table: supported type name -> class to instantiate.
    constructors = {"promp": ProMP, "dmp": DMP, "idmp": IDMP}
    primitive_cls = constructors.get(movement_primitives_type)

    if primitive_cls is None:
        raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")

    # Note the argument order: the constructors take the basis generator
    # before the action dimension, the reverse of this function's signature.
    return primitive_cls(basis_generator, action_dim, **kwargs)
|