74 lines
3.4 KiB
Python
74 lines
3.4 KiB
Python
from typing import Tuple, Union
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from fancy_gym.black_box.factory import controller_factory
|
|
|
|
|
|
@pytest.mark.parametrize('ctrl_type', controller_factory.ALL_TYPES)
def test_initialization(ctrl_type: str):
    """Build each controller type listed in ``controller_factory.ALL_TYPES``.

    No keyword arguments are passed, so this only checks that construction
    with the factory's defaults does not raise.
    """
    _ = controller_factory.get_controller(ctrl_type)
|
|
|
|
|
|
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('velocity', [np.zeros(3), np.ones(3), np.arange(3)])
def test_velocity(position: np.ndarray, velocity: np.ndarray):
    """The 'velocity' controller's action must equal the desired velocity.

    The current position/velocity arguments are passed as ``None``; the test
    expects the controller not to need them.
    """
    controller = controller_factory.get_controller('velocity')
    action = controller(position, velocity, None, None)
    assert np.array_equal(action, velocity)
|
|
|
|
|
|
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('velocity', [np.zeros(3), np.ones(3), np.arange(3)])
def test_position(position: np.ndarray, velocity: np.ndarray):
    """The 'position' controller's action must equal the desired position.

    The current position/velocity arguments are passed as ``None``; the test
    expects the controller not to need them.
    """
    controller = controller_factory.get_controller('position')
    action = controller(position, velocity, None, None)
    assert np.array_equal(action, position)
|
|
|
|
|
|
@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
@pytest.mark.parametrize('current_position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
@pytest.mark.parametrize('current_velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
@pytest.mark.parametrize('p_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
@pytest.mark.parametrize('d_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
def test_pd(position: np.ndarray, velocity: np.ndarray, current_position: np.ndarray, current_velocity: np.ndarray,
            p_gains: Union[float, np.ndarray], d_gains: Union[float, np.ndarray]):
    """Check the 'motor' controller against the closed-form PD law.

    The expected action is
    ``p_gains * (position - current_position) + d_gains * (velocity - current_velocity)``,
    where each gain is either a scalar or a per-dimension array.
    """
    # Compute the reference before the controller sees the inputs. pytest
    # reuses the same parametrized array objects across cases, so an in-place
    # modification by the controller must not be able to skew the expectation.
    expected = p_gains * (position - current_position) + d_gains * (velocity - current_velocity)

    ctrl = controller_factory.get_controller('motor', p_gains=p_gains, d_gains=d_gains)
    # The gains passed to the factory should be exposed unchanged on the controller.
    assert np.array_equal(ctrl.p_gains, p_gains)
    assert np.array_equal(ctrl.d_gains, d_gains)

    a = ctrl(position, velocity, current_position, current_velocity)
    assert np.array_equal(a, expected)
|
|
|
|
|
|
@pytest.mark.parametrize('pos_vel', [
    (np.ones(3), np.ones(4)),
    (np.ones(4), np.ones(3)),
    (np.ones(4), np.ones(4)),
])
def test_pd_invalid_shapes(pos_vel: Tuple[np.ndarray, np.ndarray]):
    """The 'motor' controller must raise ValueError on shape mismatches.

    The current position and velocity always have length 3. Each case makes
    the desired position and/or desired velocity length 4.
    """
    desired_pos, desired_vel = pos_vel
    controller = controller_factory.get_controller('motor')
    with pytest.raises(ValueError):
        controller(desired_pos, desired_vel, np.ones(3), np.ones(3))
|
|
|
|
|
|
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('current_position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('gripper_pos', [0, 1, 0.5])
def test_metaworld(position: np.ndarray, current_position: np.ndarray, gripper_pos: float):
    """Check the 'metaworld' controller's action layout.

    The desired and current vectors carry a trailing gripper entry. The test
    expects:
    - the leading entries of the action to be ``position - current_position``;
    - the last entry to be the desired gripper value, passed through unchanged.
    """
    controller = controller_factory.get_controller('metaworld')

    desired = np.append(position, gripper_pos)
    # -1 differs from every parametrized gripper_pos, so a controller that
    # used the current gripper value would fail the gripper assertion.
    current = np.append(current_position, -1)

    action = controller(desired, None, current, None)
    arm_part, grip_part = action[:-1], action[-1]
    assert grip_part == gripper_pos
    assert np.array_equal(arm_part, position - current_position)
|
|
|
|
|
|
def test_metaworld_invalid_shapes():
    """The 'metaworld' controller must reject mismatched lengths.

    A length-4 desired vector paired with a length-3 current vector should
    raise ValueError.
    """
    controller = controller_factory.get_controller('metaworld')
    with pytest.raises(ValueError):
        controller(np.ones(4), None, np.ones(3), None)
|