fancy_gym/test/test_controller.py
2022-09-30 15:07:48 +02:00

74 lines
3.4 KiB
Python

from typing import Tuple, Union
import numpy as np
import pytest
from fancy_gym.black_box.factory import controller_factory
@pytest.mark.parametrize('ctrl_type', controller_factory.ALL_TYPES)
def test_initialization(ctrl_type: str):
    """Every controller type listed in ``ALL_TYPES`` can be built via the factory.

    No assertion is needed: any exception raised during construction fails the test.
    """
    # Build with no extra keyword arguments, i.e. whatever defaults the factory applies.
    controller_factory.get_controller(ctrl_type)
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('velocity', [np.zeros(3), np.ones(3), np.arange(3)])
def test_velocity(position: np.ndarray, velocity: np.ndarray):
    """The 'velocity' controller returns the desired velocity as its action.

    The desired position is varied too, to check it does not influence the output.
    """
    controller = controller_factory.get_controller('velocity')
    # This test supplies no current position/velocity (None) to the controller.
    action = controller(position, velocity, None, None)
    assert np.array_equal(action, velocity)
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('velocity', [np.zeros(3), np.ones(3), np.arange(3)])
def test_position(position: np.ndarray, velocity: np.ndarray):
    """The 'position' controller returns the desired position as its action.

    The desired velocity is varied too, to check it does not influence the output.
    """
    controller = controller_factory.get_controller('position')
    # This test supplies no current position/velocity (None) to the controller.
    action = controller(position, velocity, None, None)
    assert np.array_equal(action, position)
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('velocity', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('current_position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('current_velocity', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('p_gains', [0, 1, 0.5, np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('d_gains', [0, 1, 0.5, np.zeros(3), np.ones(3), np.arange(3)])
def test_pd(position: np.ndarray, velocity: np.ndarray, current_position: np.ndarray, current_velocity: np.ndarray,
            p_gains: Union[float, np.ndarray], d_gains: Union[float, np.ndarray]):
    """The 'motor' controller computes p_gains * (pos error) + d_gains * (vel error).

    Gains are exercised both as scalars (int/float) and as per-entry arrays,
    so the expected value relies on NumPy broadcasting in either case.
    """
    ctrl = controller_factory.get_controller('motor', p_gains=p_gains, d_gains=d_gains)
    # The gains passed to the factory must be stored on the controller unchanged.
    assert np.array_equal(ctrl.p_gains, p_gains)
    assert np.array_equal(ctrl.d_gains, d_gains)

    action = ctrl(position, velocity, current_position, current_velocity)
    expected = p_gains * (position - current_position) + d_gains * (velocity - current_velocity)
    # Tolerant comparison: fractional gains (0.5) make the result floating point, and an
    # exact match would depend on the controller evaluating the terms in the same order.
    np.testing.assert_allclose(action, expected)
@pytest.mark.parametrize('pos_vel', [(np.ones(3), np.ones(4)),
                                     (np.ones(4), np.ones(3)),
                                     (np.ones(4), np.ones(4))])
def test_pd_invalid_shapes(pos_vel: Tuple[np.ndarray, np.ndarray]):
    """The 'motor' controller raises ValueError when desired and current shapes disagree.

    The current position and velocity always have 3 entries; every parametrized
    case makes at least one of the desired arrays 4 entries long.
    """
    desired_pos, desired_vel = pos_vel
    current_pos, current_vel = np.ones(3), np.ones(3)
    controller = controller_factory.get_controller('motor')
    with pytest.raises(ValueError):
        controller(desired_pos, desired_vel, current_pos, current_vel)
@pytest.mark.parametrize('position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('current_position', [np.zeros(3), np.ones(3), np.arange(3)])
@pytest.mark.parametrize('gripper_pos', [0, 1, 0.5])
def test_metaworld(position: np.ndarray, current_position: np.ndarray, gripper_pos: float):
    """The 'metaworld' controller returns a position delta plus the desired gripper value.

    Both desired and current inputs carry a trailing gripper entry. The first
    entries of the action must equal desired minus current, while the last entry
    must be the desired gripper value itself, not a difference.
    """
    controller = controller_factory.get_controller('metaworld')
    desired = np.append(position, gripper_pos)
    # A current gripper value of -1 never equals any gripper_pos used here, so a
    # differenced last entry would fail the gripper assertion below.
    current = np.append(current_position, -1)
    action = controller(desired, None, current, None)
    assert action[-1] == gripper_pos
    assert np.array_equal(action[:-1], position - current_position)
def test_metaworld_invalid_shapes():
    """The 'metaworld' controller raises ValueError when desired and current lengths differ."""
    controller = controller_factory.get_controller('metaworld')
    desired, current = np.ones(4), np.ones(3)
    with pytest.raises(ValueError):
        controller(desired, None, current, None)