diff --git a/fancy_gym/black_box/controller/base_controller.py b/fancy_gym/black_box/controller/base_controller.py
index 1ac1522..e9045aa 100644
--- a/fancy_gym/black_box/controller/base_controller.py
+++ b/fancy_gym/black_box/controller/base_controller.py
@@ -2,3 +2,6 @@ class BaseController:
     def get_action(self, des_pos, des_vel, c_pos, c_vel):
         raise NotImplementedError
+
+    def __call__(self, des_pos, des_vel, c_pos, c_vel):
+        return self.get_action(des_pos, des_vel, c_pos, c_vel)
diff --git a/fancy_gym/black_box/controller/meta_world_controller.py b/fancy_gym/black_box/controller/meta_world_controller.py
index efd8983..3e5bd37 100644
--- a/fancy_gym/black_box/controller/meta_world_controller.py
+++ b/fancy_gym/black_box/controller/meta_world_controller.py
@@ -18,7 +18,8 @@ class MetaWorldController(BaseController):
         cur_pos = c_pos[:-1]
         xyz_pos = des_pos[:-1]
-        assert xyz_pos.shape == cur_pos.shape, \
-            f"Mismatch in dimension between desired position {xyz_pos.shape} and current position {cur_pos.shape}"
+        if xyz_pos.shape != cur_pos.shape:
+            raise ValueError(f"Mismatch in dimension between desired position"
+                             f" {xyz_pos.shape} and current position {cur_pos.shape}")
         trq = np.hstack([(xyz_pos - cur_pos), gripper_pos])
         return trq
diff --git a/fancy_gym/black_box/controller/pd_controller.py b/fancy_gym/black_box/controller/pd_controller.py
index 35203d8..78c2adc 100644
--- a/fancy_gym/black_box/controller/pd_controller.py
+++ b/fancy_gym/black_box/controller/pd_controller.py
@@ -8,7 +8,6 @@ class PDController(BaseController):
     A PD-Controller. Using position and velocity information from a provided environment,
     the tracking_controller calculates a response based on the desired position and velocity
-    :param env: A position environment
     :param p_gains: Factors for the proportional gains
     :param d_gains: Factors for the differential gains
     """
@@ -20,9 +19,11 @@ class PDController(BaseController):
         self.d_gains = d_gains
 
     def get_action(self, des_pos, des_vel, c_pos, c_vel):
-        assert des_pos.shape == c_pos.shape, \
-            f"Mismatch in dimension between desired position {des_pos.shape} and current position {c_pos.shape}"
-        assert des_vel.shape == c_vel.shape, \
-            f"Mismatch in dimension between desired velocity {des_vel.shape} and current velocity {c_vel.shape}"
+        if des_pos.shape != c_pos.shape:
+            raise ValueError(f"Mismatch in dimension between desired position "
+                             f"{des_pos.shape} and current position {c_pos.shape}")
+        if des_vel.shape != c_vel.shape:
+            raise ValueError(f"Mismatch in dimension between desired velocity"
+                             f" {des_vel.shape} and current velocity {c_vel.shape}")
         trq = self.p_gains * (des_pos - c_pos) + self.d_gains * (des_vel - c_vel)
         return trq
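With `BaseController.__call__` delegating to `get_action`, controllers can now be invoked like plain functions, and dimension mismatches surface as `ValueError` instead of `AssertionError`. A minimal sketch of the new calling convention, using the `motor` (PD) controller from `controller_factory` as exercised in `test/test_controller.py` further down; the gain values and trajectory points are illustrative only:

```python
import numpy as np

from fancy_gym.black_box.factory import controller_factory

# PD controller ('motor'); scalar gains broadcast over all DoFs
ctrl = controller_factory.get_controller('motor', p_gains=1.0, d_gains=0.1)

des_pos, des_vel = np.ones(3), np.zeros(3)
cur_pos, cur_vel = np.zeros(3), np.zeros(3)

# Controllers are now callable; this is equivalent to ctrl.get_action(...)
torque = ctrl(des_pos, des_vel, cur_pos, cur_vel)

try:
    # Mismatched dimensions now raise ValueError instead of tripping an assert
    ctrl(np.ones(4), np.ones(4), cur_pos, cur_vel)
except ValueError as e:
    print(f"rejected: {e}")
```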
basis_generator_type == "prodmp": + from mp_pytorch.phase_gn import ExpDecayPhaseGenerator + assert isinstance(phase_generator, ExpDecayPhaseGenerator) return ProDMPBasisGenerator(phase_generator, **kwargs) elif basis_generator_type == "rhythmic": raise NotImplementedError() diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 8e1aa3d..a12f057 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -248,7 +248,6 @@ register( max_episode_steps=FIXED_RELEASE_STEP, ) - # movement Primitive Environments ## Simple Reacher diff --git a/fancy_gym/meta/README.MD b/fancy_gym/meta/README.MD index c8d9cd1..1664cb0 100644 --- a/fancy_gym/meta/README.MD +++ b/fancy_gym/meta/README.MD @@ -1,6 +1,6 @@ # MetaWorld Wrappers -These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them. +These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Movement Primitive gym interface with them. All Metaworld environments have a 39 dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations that are randomized. Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks the following: ```python diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py index 104d4e7..4fb23b2 100644 --- a/fancy_gym/meta/__init__.py +++ b/fancy_gym/meta/__init__.py @@ -28,11 +28,31 @@ DEFAULT_BB_DICT_ProMP = { } } +DEFAULT_BB_DICT_ProDMP = { + "name": 'EnvName', + "wrappers": [], + "trajectory_generator_kwargs": { + 'trajectory_generator_type': 'prodmp' + }, + "phase_generator_kwargs": { + 'phase_generator_type': 'exp' + }, + "controller_kwargs": { + 'controller_type': 'metaworld', + }, + "basis_generator_kwargs": { + 'basis_generator_type': 'prodmp', + 'num_basis': 5 + } +} + _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2", "plate-slide-side-v2", "plate-slide-back-side-v2"] for _task in _goal_change_envs: task_id_split = _task.split("-") name = "".join([s.capitalize() for s in task_id_split[:-1]]) + + # ProMP _env_id = f'{name}ProMP-{task_id_split[-1]}' kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper) @@ -45,10 +65,25 @@ for _task in _goal_change_envs: ) ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) + # ProDMP + _env_id = f'{name}ProDMP-{task_id_split[-1]}' + kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP) + kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper) + kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}' + + register( + id=_env_id, + entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', + kwargs=kwargs_dict_goal_change_prodmp + ) + ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id) + _object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"] for _task in _object_change_envs: task_id_split = _task.split("-") name = "".join([s.capitalize() for s in task_id_split[:-1]]) + + # ProMP _env_id = f'{name}ProMP-{task_id_split[-1]}' kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper) @@ -60,6 +95,18 @@ for _task in _object_change_envs: ) 
diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py
index 8e1aa3d..a12f057 100644
--- a/fancy_gym/envs/__init__.py
+++ b/fancy_gym/envs/__init__.py
@@ -248,7 +248,6 @@ register(
     max_episode_steps=FIXED_RELEASE_STEP,
 )
-
 # movement Primitive Environments
 ## Simple Reacher
diff --git a/fancy_gym/meta/README.MD b/fancy_gym/meta/README.MD
index c8d9cd1..1664cb0 100644
--- a/fancy_gym/meta/README.MD
+++ b/fancy_gym/meta/README.MD
@@ -1,6 +1,6 @@
 # MetaWorld Wrappers
-These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them.
+These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Movement Primitive gym interface with them.
 All Metaworld environments have a 39 dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations that are randomized.
 Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks the following:
 ```python
diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py
index 104d4e7..4fb23b2 100644
--- a/fancy_gym/meta/__init__.py
+++ b/fancy_gym/meta/__init__.py
@@ -28,11 +28,31 @@ DEFAULT_BB_DICT_ProMP = {
     }
 }
+DEFAULT_BB_DICT_ProDMP = {
+    "name": 'EnvName',
+    "wrappers": [],
+    "trajectory_generator_kwargs": {
+        'trajectory_generator_type': 'prodmp'
+    },
+    "phase_generator_kwargs": {
+        'phase_generator_type': 'exp'
+    },
+    "controller_kwargs": {
+        'controller_type': 'metaworld',
+    },
+    "basis_generator_kwargs": {
+        'basis_generator_type': 'prodmp',
+        'num_basis': 5
+    }
+}
+
 _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
                      "plate-slide-side-v2", "plate-slide-back-side-v2"]
 for _task in _goal_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
+
+    # ProMP
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
     kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
     kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
@@ -45,10 +65,25 @@ for _task in _goal_change_envs:
     )
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+    # ProDMP
+    _env_id = f'{name}ProDMP-{task_id_split[-1]}'
+    kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}'
+
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_change_prodmp
+    )
+    ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
+
 _object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
 for _task in _object_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
+
+    # ProMP
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
     kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
     kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
     kwargs_dict_object_change_promp['name'] = f'metaworld:{_task}'
     register(
         id=_env_id,
         entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
         kwargs=kwargs_dict_object_change_promp
     )
@@ -60,6 +95,18 @@ for _task in _object_change_envs:
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+    # ProDMP
+    _env_id = f'{name}ProDMP-{task_id_split[-1]}'
+    kwargs_dict_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_object_change_prodmp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
+    kwargs_dict_object_change_prodmp['name'] = f'metaworld:{_task}'
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_object_change_prodmp
+    )
+    ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
+
 _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
                                 "button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2", "coffee-push-v2",
                                 "dial-turn-v2", "disassemble-v2", "door-close-v2",
@@ -74,6 +121,8 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
 for _task in _goal_and_object_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
+
+    # ProMP
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
     kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
     kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
@@ -86,10 +135,26 @@ for _task in _goal_and_object_change_envs:
     )
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+    # ProDMP
+    _env_id = f'{name}ProDMP-{task_id_split[-1]}'
+    kwargs_dict_goal_and_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_goal_and_object_change_prodmp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_and_object_change_prodmp['name'] = f'metaworld:{_task}'
+
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_and_object_change_prodmp
+    )
+    ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
+
+
 _goal_and_endeffector_change_envs = ["basketball-v2"]
 for _task in _goal_and_endeffector_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
+
+    # ProMP
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
     kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
     kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
@@ -101,3 +166,16 @@ for _task in _goal_and_endeffector_change_envs:
     kwargs=kwargs_dict_goal_and_endeffector_change_promp
     )
     ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
+    # ProDMP
+    _env_id = f'{name}ProDMP-{task_id_split[-1]}'
+    kwargs_dict_goal_and_endeffector_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
+    kwargs_dict_goal_and_endeffector_change_prodmp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_and_endeffector_change_prodmp['name'] = f'metaworld:{_task}'
+
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_and_endeffector_change_prodmp
+    )
+    ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
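The net effect of the registrations above is that every wrapped MetaWorld task is now available as a ProDMP black-box environment next to its existing ProMP variant. A hedged usage sketch; the id `ButtonPressProDMP-v2` is derived from `button-press-v2` via the `f'{name}ProDMP-{version}'` naming scheme and is assumed here for illustration:

```python
import fancy_gym

# All newly registered ProDMP ids are collected by the loops above
print(fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"])

# 'button-press-v2' -> 'ButtonPressProDMP-v2' (assumed id, following the naming scheme)
env = fancy_gym.make('ButtonPressProDMP-v2', 1)

obs = env.reset()
# One step executes the whole motion generated from the sampled ProDMP parameters
obs, reward, done, info = env.step(env.action_space.sample())
```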
diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py
index 5221423..18ab6ed 100644
--- a/fancy_gym/utils/make_env_helpers.py
+++ b/fancy_gym/utils/make_env_helpers.py
@@ -141,7 +141,7 @@ def make_bb(
     Returns: DMP wrapped gym env
     """
-    _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
+    _verify_time_limit(traj_gen_kwargs.get("duration"), kwargs.get("time_limit"))
     learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
     do_replanning = black_box_kwargs.get('replanning_schedule')
diff --git a/setup.py b/setup.py
index 1148e85..c029591 100644
--- a/setup.py
+++ b/setup.py
@@ -18,14 +18,19 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
 setup(
     author='Fabian Otto, Onur Celik',
     name='fancy_gym',
-    version='0.2',
+    version='0.3',
     classifiers=[
-        # Python 3.7 is minimally supported
-        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.7",
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: MIT License',
+        'Natural Language :: English',
+        'Operating System :: OS Independent',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Programming Language :: Python :: 3',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+        'Programming Language :: Python :: 3.10',
     ],
     extras_require=extras,
     install_requires=[
@@ -40,7 +45,7 @@ setup(
     },
     python_requires=">=3.7",
     url='https://github.com/ALRhub/fancy_gym/',
-    # license='AGPL-3.0 license',
+    license='MIT license',
     author_email='',
     description='Fancy Gym: Unifying interface for various RL benchmarks with support for Black Box approaches.'
 )
diff --git a/test/test_black_box.py b/test/test_black_box.py
new file mode 100644
index 0000000..fb7f78c
--- /dev/null
+++ b/test/test_black_box.py
@@ -0,0 +1,215 @@
+from itertools import chain
+from typing import Tuple, Type, Union, Optional
+
+import gym
+import numpy as np
+import pytest
+from gym import register
+from gym.core import ActType, ObsType
+
+import fancy_gym
+from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
+from fancy_gym.utils.time_aware_observation import TimeAwareObservation
+
+SEED = 1
+ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch', 'metaworld:reach-v2', 'Reacher-v2']
+WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
+            fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
+ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
+
+
+class Object(object):
+    pass
+
+
+class ToyEnv(gym.Env):
+    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
+    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
+    dt = 0.01
+
+    def __init__(self, a: int = 0, b: float = 0.0, c: list = [], d: dict = {}, e: Object = Object()):
+        self.a, self.b, self.c, self.d, self.e = a, b, c, d, e
+
+    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
+              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
+        return np.array([-1])
+
+    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
+        return np.array([-1]), 1, False, {}
+
+    def render(self, mode="human"):
+        pass
+
+
+class ToyWrapper(RawInterfaceWrapper):
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+        return np.ones(self.action_space.shape)
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        return np.ones(self.action_space.shape)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def setup():
+    register(
+        id=f'toy-v0',
+        entry_point='test.test_black_box:ToyEnv',
+        max_episode_steps=50,
+    )
+
+
+@pytest.mark.parametrize('env_id', ENV_IDS)
+def test_missing_wrapper(env_id: str):
+    with pytest.raises(ValueError):
+        fancy_gym.make_bb(env_id, [], {}, {}, {}, {}, {})
+
+
+def test_missing_local_state():
+    env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+    env.reset()
+    with pytest.raises(NotImplementedError):
+        env.step(env.action_space.sample())
+
+
+@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
+@pytest.mark.parametrize('verbose', [1, 2])
+def test_verbosity(env_wrap: Tuple[str, Type[RawInterfaceWrapper]], verbose: int):
+    env_id, wrapper_class = env_wrap
+    env = fancy_gym.make_bb(env_id, [wrapper_class], {},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+    env.reset()
+    info_keys = env.step(env.action_space.sample())[3].keys()
+
+    env_step = fancy_gym.make(env_id, SEED)
+    env_step.reset()
+    info_keys_step = env_step.step(env_step.action_space.sample())[3].keys()
+
+    assert all(k in info_keys for k in info_keys_step)
+    assert 'trajectory_length' in info_keys
+
+    if verbose >= 2:
+        mp_keys = ['position', 'velocities', 'step_actions', 'step_observations', 'step_rewards']
+        assert all(k in info_keys for k in mp_keys)
+
+
+@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
+def test_length(env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
+    env_id, wrapper_class = env_wrap
+    env = fancy_gym.make_bb(env_id, [wrapper_class], {},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+    env.reset()
+    length = env.step(env.action_space.sample())[3]['trajectory_length']
+
+    assert length == env.spec.max_episode_steps
+
+
+@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
+def test_aggregation(reward_aggregation: callable):
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+    env.reset()
+    # ToyEnv only returns 1 as reward
+    assert env.step(env.action_space.sample())[1] == reward_aggregation(np.ones(50, ))
+
+
+@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
+def test_context_space(env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
+    env_id, wrapper_class = env_wrap
+    env = fancy_gym.make_bb(env_id, [wrapper_class], {},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+    # check if observation space matches with the specified mask values which are true
+    env_step = fancy_gym.make(env_id, SEED)
+    wrapper = wrapper_class(env_step)
+    assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
+
+
+@pytest.mark.parametrize('num_dof', [0, 1, 2, 5])
+@pytest.mark.parametrize('num_basis', [0, 1, 2, 5])
+@pytest.mark.parametrize('learn_tau', [True, False])
+@pytest.mark.parametrize('learn_delay', [True, False])
+def test_action_space(num_dof: int, num_basis: int, learn_tau: bool, learn_delay: bool):
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
+                            {'trajectory_generator_type': 'promp',
+                             'action_dim': num_dof
+                             },
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear',
+                             'learn_tau': learn_tau,
+                             'learn_delay': learn_delay
+                             },
+                            {'basis_generator_type': 'rbf',
+                             'num_basis': num_basis
+                             })
+    assert env.action_space.shape[0] == num_dof * num_basis + int(learn_tau) + int(learn_delay)
+
+
+@pytest.mark.parametrize('a', [1])
+@pytest.mark.parametrize('b', [1.0])
+@pytest.mark.parametrize('c', [[1], [1.0], ['str'], [{'a': 'b'}], [np.ones(3, )]])
+@pytest.mark.parametrize('d', [{'a': 1}, {1: 2.0}, {'a': [1.0]}, {'a': np.ones(3, )}, {'a': {'a': 'b'}}])
+@pytest.mark.parametrize('e', [Object()])
+def test_change_env_kwargs(a: int, b: float, c: list, d: dict, e: Object):
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'},
+                            a=a, b=b, c=c, d=d, e=e
+                            )
+    assert a is env.a
+    assert b is env.b
+    assert c is env.c
+    # Due to how gym works dict kwargs need to be copied and hence can only be checked to have the same content
+    assert d == env.d
+    assert e is env.e
+
+
+@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
+@pytest.mark.parametrize('add_time_aware_wrapper_before', [True, False])
+def test_learn_sub_trajectories(env_wrap: Tuple[str, Type[RawInterfaceWrapper]], add_time_aware_wrapper_before: bool):
+    env_id, wrapper_class = env_wrap
+    env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
+    wrappers = [wrapper_class]
+
+    # has time aware wrapper
+    if add_time_aware_wrapper_before:
+        wrappers += [TimeAwareObservation]
+
+    env = fancy_gym.make_bb(env_id, wrappers, {'learn_sub_trajectories': True},
+                            {'trajectory_generator_type': 'promp'},
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': 'linear'},
+                            {'basis_generator_type': 'rbf'})
+
+    assert env.learn_sub_trajectories
+    assert env.traj_gen.learn_tau
+    assert env.observation_space == env_step.observation_space
+
+    env.reset()
+    action = env.action_space.sample()
+    obs, r, d, info = env.step(action)
+
+    length = info['trajectory_length']
+
+    factor = 1 / env.dt
+    assert np.allclose(length * env.dt, np.round(factor * action[0]) / factor)
+    assert np.allclose(length * env.dt, np.round(factor * env.traj_gen.tau.numpy()) / factor)
diff --git a/test/test_controller.py b/test/test_controller.py
new file mode 100644
index 0000000..c530c50
--- /dev/null
+++ b/test/test_controller.py
@@ -0,0 +1,73 @@
+from typing import Tuple, Union
+
+import numpy as np
+import pytest
+
+from fancy_gym.black_box.factory import controller_factory
+
+
+@pytest.mark.parametrize('ctrl_type', controller_factory.ALL_TYPES)
+def test_initialization(ctrl_type: str):
+    controller_factory.get_controller(ctrl_type)
+
+
+@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+def test_velocity(position: np.ndarray, velocity: np.ndarray):
+    ctrl = controller_factory.get_controller('velocity')
+    a = ctrl(position, velocity, None, None)
+    assert np.array_equal(a, velocity)
+
+
+@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+def test_position(position: np.ndarray, velocity: np.ndarray):
+    ctrl = controller_factory.get_controller('position')
+    a = ctrl(position, velocity, None, None)
+    assert np.array_equal(a, position)
+
+
+@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('current_position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('current_velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('p_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('d_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+def test_pd(position: np.ndarray, velocity: np.ndarray, current_position: np.ndarray, current_velocity: np.ndarray,
+            p_gains: Union[float, Tuple], d_gains: Union[float, Tuple]):
+    ctrl = controller_factory.get_controller('motor', p_gains=p_gains, d_gains=d_gains)
+    assert np.array_equal(ctrl.p_gains, p_gains)
+    assert np.array_equal(ctrl.d_gains, d_gains)
+
+    a = ctrl(position, velocity, current_position, current_velocity)
+    pd = p_gains * (position - current_position) + d_gains * (velocity - current_velocity)
+    assert np.array_equal(a, pd)
+
+
+@pytest.mark.parametrize('pos_vel', [(np.ones(3, ), np.ones(4, )),
+                                     (np.ones(4, ), np.ones(3, )),
+                                     (np.ones(4, ), np.ones(4, ))])
+def test_pd_invalid_shapes(pos_vel: Tuple[np.ndarray, np.ndarray]):
+    position, velocity = pos_vel
+    ctrl = controller_factory.get_controller('motor')
+    with pytest.raises(ValueError):
+        ctrl(position, velocity, np.ones(3, ), np.ones(3, ))
+
+
+@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('current_position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
+@pytest.mark.parametrize('gripper_pos', [0, 1, 0.5])
+def test_metaworld(position: np.ndarray, current_position: np.ndarray, gripper_pos: float):
+    ctrl = controller_factory.get_controller('metaworld')
+
+    position_grip = np.append(position, gripper_pos)
+    c_position_grip = np.append(current_position, -1)
+    a = ctrl(position_grip, None, c_position_grip, None)
+    assert a[-1] == gripper_pos
+    assert np.array_equal(a[:-1], position - current_position)
+
+
+def test_metaworld_invalid_shapes():
+    ctrl = controller_factory.get_controller('metaworld')
+    with pytest.raises(ValueError):
+        ctrl(np.ones(4, ), None, np.ones(3, ), None)
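The MetaWorld controller tests above pin down the action layout: the last entry of the desired position is forwarded unchanged as the gripper command, while the leading entries become a position delta. A small standalone sketch of that contract, with illustrative values only:

```python
import numpy as np

from fancy_gym.black_box.factory import controller_factory

ctrl = controller_factory.get_controller('metaworld')

des_xyz = np.array([0.1, 0.2, 0.3])
cur_xyz = np.zeros(3)
gripper = 0.5

# Desired and current positions carry the gripper state in their last slot
action = ctrl(np.append(des_xyz, gripper), None, np.append(cur_xyz, -1), None)

assert np.allclose(action[:-1], des_xyz - cur_xyz)  # xyz position delta
assert action[-1] == gripper                        # gripper command passed through
```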