Merge remote-tracking branch 'origin/Add-ProDMP-envs' into mujoco_binding
# Conflicts: # fancy_gym/black_box/factory/basis_generator_factory.py
This commit is contained in:
commit
f16a128d57
@ -2,3 +2,6 @@ class BaseController:
|
|||||||
|
|
||||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __call__(self, des_pos, des_vel, c_pos, c_vel):
|
||||||
|
return self.get_action(des_pos, des_vel, c_pos, c_vel)
|
||||||
|
@ -18,7 +18,8 @@ class MetaWorldController(BaseController):
|
|||||||
cur_pos = c_pos[:-1]
|
cur_pos = c_pos[:-1]
|
||||||
xyz_pos = des_pos[:-1]
|
xyz_pos = des_pos[:-1]
|
||||||
|
|
||||||
assert xyz_pos.shape == cur_pos.shape, \
|
if xyz_pos.shape != cur_pos.shape:
|
||||||
f"Mismatch in dimension between desired position {xyz_pos.shape} and current position {cur_pos.shape}"
|
raise ValueError(f"Mismatch in dimension between desired position"
|
||||||
|
f" {xyz_pos.shape} and current position {cur_pos.shape}")
|
||||||
trq = np.hstack([(xyz_pos - cur_pos), gripper_pos])
|
trq = np.hstack([(xyz_pos - cur_pos), gripper_pos])
|
||||||
return trq
|
return trq
|
||||||
|
@ -8,7 +8,6 @@ class PDController(BaseController):
|
|||||||
A PD-Controller. Using position and velocity information from a provided environment,
|
A PD-Controller. Using position and velocity information from a provided environment,
|
||||||
the tracking_controller calculates a response based on the desired position and velocity
|
the tracking_controller calculates a response based on the desired position and velocity
|
||||||
|
|
||||||
:param env: A position environment
|
|
||||||
:param p_gains: Factors for the proportional gains
|
:param p_gains: Factors for the proportional gains
|
||||||
:param d_gains: Factors for the differential gains
|
:param d_gains: Factors for the differential gains
|
||||||
"""
|
"""
|
||||||
@ -20,9 +19,11 @@ class PDController(BaseController):
|
|||||||
self.d_gains = d_gains
|
self.d_gains = d_gains
|
||||||
|
|
||||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||||
assert des_pos.shape == c_pos.shape, \
|
if des_pos.shape != c_pos.shape:
|
||||||
f"Mismatch in dimension between desired position {des_pos.shape} and current position {c_pos.shape}"
|
raise ValueError(f"Mismatch in dimension between desired position "
|
||||||
assert des_vel.shape == c_vel.shape, \
|
f"{des_pos.shape} and current position {c_pos.shape}")
|
||||||
f"Mismatch in dimension between desired velocity {des_vel.shape} and current velocity {c_vel.shape}"
|
if des_vel.shape != c_vel.shape:
|
||||||
|
raise ValueError(f"Mismatch in dimension between desired velocity"
|
||||||
|
f" {des_vel.shape} and current velocity {c_vel.shape}")
|
||||||
trq = self.p_gains * (des_pos - c_pos) + self.d_gains * (des_vel - c_vel)
|
trq = self.p_gains * (des_pos - c_pos) + self.d_gains * (des_vel - c_vel)
|
||||||
return trq
|
return trq
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator, ProDMPBasisGenerator
|
from mp_pytorch.basis_gn import NormalizedRBFBasisGenerator, ZeroPaddingNormalizedRBFBasisGenerator, \
|
||||||
|
ProDMPBasisGenerator
|
||||||
from mp_pytorch.phase_gn import PhaseGenerator
|
from mp_pytorch.phase_gn import PhaseGenerator
|
||||||
|
|
||||||
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
ALL_TYPES = ["rbf", "zero_rbf", "rhythmic"]
|
||||||
@ -11,6 +12,8 @@ def get_basis_generator(basis_generator_type: str, phase_generator: PhaseGenerat
|
|||||||
elif basis_generator_type == "zero_rbf":
|
elif basis_generator_type == "zero_rbf":
|
||||||
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
return ZeroPaddingNormalizedRBFBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "prodmp":
|
elif basis_generator_type == "prodmp":
|
||||||
|
from mp_pytorch.phase_gn import ExpDecayPhaseGenerator
|
||||||
|
assert isinstance(phase_generator, ExpDecayPhaseGenerator)
|
||||||
return ProDMPBasisGenerator(phase_generator, **kwargs)
|
return ProDMPBasisGenerator(phase_generator, **kwargs)
|
||||||
elif basis_generator_type == "rhythmic":
|
elif basis_generator_type == "rhythmic":
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -248,7 +248,6 @@ register(
|
|||||||
max_episode_steps=FIXED_RELEASE_STEP,
|
max_episode_steps=FIXED_RELEASE_STEP,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# movement Primitive Environments
|
# movement Primitive Environments
|
||||||
|
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# MetaWorld Wrappers
|
# MetaWorld Wrappers
|
||||||
|
|
||||||
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Motion Primitive gym interface with them.
|
These are the Environment Wrappers for selected [Metaworld](https://meta-world.github.io/) environments in order to use our Movement Primitive gym interface with them.
|
||||||
All Metaworld environments have a 39 dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations that are randomized.
|
All Metaworld environments have a 39 dimensional observation space with the same structure. The tasks differ only in the objective and the initial observations that are randomized.
|
||||||
Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks the following:
|
Unused observations are zeroed out. E.g. for `Button-Press-v2` the observation mask looks the following:
|
||||||
```python
|
```python
|
||||||
|
@ -28,11 +28,31 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DEFAULT_BB_DICT_ProDMP = {
|
||||||
|
"name": 'EnvName',
|
||||||
|
"wrappers": [],
|
||||||
|
"trajectory_generator_kwargs": {
|
||||||
|
'trajectory_generator_type': 'prodmp'
|
||||||
|
},
|
||||||
|
"phase_generator_kwargs": {
|
||||||
|
'phase_generator_type': 'exp'
|
||||||
|
},
|
||||||
|
"controller_kwargs": {
|
||||||
|
'controller_type': 'metaworld',
|
||||||
|
},
|
||||||
|
"basis_generator_kwargs": {
|
||||||
|
'basis_generator_type': 'prodmp',
|
||||||
|
'num_basis': 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
_goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
|
||||||
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
"plate-slide-side-v2", "plate-slide-back-side-v2"]
|
||||||
for _task in _goal_change_envs:
|
for _task in _goal_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||||
@ -45,10 +65,25 @@ for _task in _goal_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_change_prodmp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
|
||||||
for _task in _object_change_envs:
|
for _task in _object_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||||
@ -60,6 +95,18 @@ for _task in _object_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_object_change_prodmp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_object_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
|
||||||
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
|
||||||
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
"coffee-push-v2", "dial-turn-v2", "disassemble-v2", "door-close-v2",
|
||||||
@ -74,6 +121,8 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
|
|||||||
for _task in _goal_and_object_change_envs:
|
for _task in _goal_and_object_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||||
@ -86,10 +135,26 @@ for _task in _goal_and_object_change_envs:
|
|||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_and_object_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_and_object_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
|
||||||
|
|
||||||
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
_goal_and_endeffector_change_envs = ["basketball-v2"]
|
||||||
for _task in _goal_and_endeffector_change_envs:
|
for _task in _goal_and_endeffector_change_envs:
|
||||||
task_id_split = _task.split("-")
|
task_id_split = _task.split("-")
|
||||||
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
name = "".join([s.capitalize() for s in task_id_split[:-1]])
|
||||||
|
|
||||||
|
# ProMP
|
||||||
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
_env_id = f'{name}ProMP-{task_id_split[-1]}'
|
||||||
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||||
@ -101,3 +166,16 @@ for _task in _goal_and_endeffector_change_envs:
|
|||||||
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
kwargs=kwargs_dict_goal_and_endeffector_change_promp
|
||||||
)
|
)
|
||||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
# ProDMP
|
||||||
|
_env_id = f'{name}ProDMP-{task_id_split[-1]}'
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
|
||||||
|
kwargs_dict_goal_and_endeffector_change_prodmp['name'] = f'metaworld:{_task}'
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=_env_id,
|
||||||
|
entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
|
||||||
|
kwargs=kwargs_dict_goal_and_endeffector_change_prodmp
|
||||||
|
)
|
||||||
|
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
|
||||||
|
@ -141,7 +141,7 @@ def make_bb(
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
_verify_time_limit(traj_gen_kwargs.get("duration"), kwargs.get("time_limit"))
|
||||||
|
|
||||||
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
|
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
|
||||||
do_replanning = black_box_kwargs.get('replanning_schedule')
|
do_replanning = black_box_kwargs.get('replanning_schedule')
|
||||||
|
21
setup.py
21
setup.py
@ -18,14 +18,19 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
|
|||||||
setup(
|
setup(
|
||||||
author='Fabian Otto, Onur Celik',
|
author='Fabian Otto, Onur Celik',
|
||||||
name='fancy_gym',
|
name='fancy_gym',
|
||||||
version='0.2',
|
version='0.3',
|
||||||
classifiers=[
|
classifiers=[
|
||||||
# Python 3.7 is minimally supported
|
'Development Status :: 4 - Beta',
|
||||||
"Programming Language :: Python :: 3",
|
'Intended Audience :: Science/Research',
|
||||||
"Programming Language :: Python :: 3.7",
|
'License :: OSI Approved :: MIT License',
|
||||||
"Programming Language :: Python :: 3.8",
|
'Natural Language :: English',
|
||||||
"Programming Language :: Python :: 3.9",
|
'Operating System :: OS Independent',
|
||||||
"Programming Language :: Python :: 3.10",
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.7',
|
||||||
|
'Programming Language :: Python :: 3.8',
|
||||||
|
'Programming Language :: Python :: 3.9',
|
||||||
|
'Programming Language :: Python :: 3.10',
|
||||||
],
|
],
|
||||||
extras_require=extras,
|
extras_require=extras,
|
||||||
install_requires=[
|
install_requires=[
|
||||||
@ -40,7 +45,7 @@ setup(
|
|||||||
},
|
},
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.7",
|
||||||
url='https://github.com/ALRhub/fancy_gym/',
|
url='https://github.com/ALRhub/fancy_gym/',
|
||||||
# license='AGPL-3.0 license',
|
license='MIT license',
|
||||||
author_email='',
|
author_email='',
|
||||||
description='Fancy Gym: Unifying interface for various RL benchmarks with support for Black Box approaches.'
|
description='Fancy Gym: Unifying interface for various RL benchmarks with support for Black Box approaches.'
|
||||||
)
|
)
|
||||||
|
215
test/test_black_box.py
Normal file
215
test/test_black_box.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
from itertools import chain
|
||||||
|
from typing import Tuple, Type, Union, Optional
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from gym import register
|
||||||
|
from gym.core import ActType, ObsType
|
||||||
|
|
||||||
|
import fancy_gym
|
||||||
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
||||||
|
|
||||||
|
SEED = 1
|
||||||
|
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch', 'metaworld:reach-v2', 'Reacher-v2']
|
||||||
|
WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
|
||||||
|
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
|
||||||
|
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
|
|
||||||
|
|
||||||
|
class Object(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToyEnv(gym.Env):
|
||||||
|
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
|
||||||
|
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
|
||||||
|
dt = 0.01
|
||||||
|
|
||||||
|
def __init__(self, a: int = 0, b: float = 0.0, c: list = [], d: dict = {}, e: Object = Object()):
|
||||||
|
self.a, self.b, self.c, self.d, self.e = a, b, c, d, e
|
||||||
|
|
||||||
|
def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
|
||||||
|
options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
|
return np.array([-1])
|
||||||
|
|
||||||
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
|
||||||
|
return np.array([-1]), 1, False, {}
|
||||||
|
|
||||||
|
def render(self, mode="human"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ToyWrapper(RawInterfaceWrapper):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
return np.ones(self.action_space.shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
return np.ones(self.action_space.shape)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def setup():
|
||||||
|
register(
|
||||||
|
id=f'toy-v0',
|
||||||
|
entry_point='test.test_black_box:ToyEnv',
|
||||||
|
max_episode_steps=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('env_id', ENV_IDS)
|
||||||
|
def test_missing_wrapper(env_id: str):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
fancy_gym.make_bb(env_id, [], {}, {}, {}, {}, {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_local_state():
|
||||||
|
env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
env.reset()
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
env.step(env.action_space.sample())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||||
|
@pytest.mark.parametrize('verbose', [1, 2])
|
||||||
|
def test_verbosity(env_wrap: Tuple[str, Type[RawInterfaceWrapper]], verbose: int):
|
||||||
|
env_id, wrapper_class = env_wrap
|
||||||
|
env = fancy_gym.make_bb(env_id, [wrapper_class], {},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
env.reset()
|
||||||
|
info_keys = env.step(env.action_space.sample())[3].keys()
|
||||||
|
|
||||||
|
env_step = fancy_gym.make(env_id, SEED)
|
||||||
|
env_step.reset()
|
||||||
|
info_keys_step = env_step.step(env_step.action_space.sample())[3].keys()
|
||||||
|
|
||||||
|
assert info_keys_step in info_keys
|
||||||
|
assert 'trajectory_length' in info_keys
|
||||||
|
|
||||||
|
if verbose >= 2:
|
||||||
|
mp_keys = ['position', 'velocities', 'step_actions', 'step_observations', 'step_rewards']
|
||||||
|
assert mp_keys in info_keys
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||||
|
def test_length(env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||||
|
env_id, wrapper_class = env_wrap
|
||||||
|
env = fancy_gym.make_bb(env_id, [wrapper_class], {},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
env.reset()
|
||||||
|
length = env.step(env.action_space.sample())[3]['trajectory_length']
|
||||||
|
|
||||||
|
assert length == env.spec.max_episode_steps
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
|
||||||
|
def test_aggregation(reward_aggregation: callable):
|
||||||
|
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
env.reset()
|
||||||
|
# ToyEnv only returns 1 as reward
|
||||||
|
assert env.step(env.action_space.sample())[1] == reward_aggregation(np.ones(50, ))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||||
|
def test_context_space(env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||||
|
env_id, wrapper_class = env_wrap
|
||||||
|
env = fancy_gym.make_bb(env_id, [wrapper_class], {},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
# check if observation space matches with the specified mask values which are true
|
||||||
|
env_step = fancy_gym.make(env_id, SEED)
|
||||||
|
wrapper = wrapper_class(env_step)
|
||||||
|
assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('num_dof', [0, 1, 2, 5])
|
||||||
|
@pytest.mark.parametrize('num_basis', [0, 1, 2, 5])
|
||||||
|
@pytest.mark.parametrize('learn_tau', [True, False])
|
||||||
|
@pytest.mark.parametrize('learn_delay', [True, False])
|
||||||
|
def test_action_space(num_dof: int, num_basis: int, learn_tau: bool, learn_delay: bool):
|
||||||
|
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
|
||||||
|
{'trajectory_generator_type': 'promp',
|
||||||
|
'action_dim': num_dof
|
||||||
|
},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear',
|
||||||
|
'learn_tau': learn_tau,
|
||||||
|
'learn_delay': learn_delay
|
||||||
|
},
|
||||||
|
{'basis_generator_type': 'rbf',
|
||||||
|
'num_basis': num_basis
|
||||||
|
})
|
||||||
|
assert env.action_space.shape[0] == num_dof * num_basis + int(learn_tau) + int(learn_delay)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('a', [1])
|
||||||
|
@pytest.mark.parametrize('b', [1.0])
|
||||||
|
@pytest.mark.parametrize('c', [[1], [1.0], ['str'], [{'a': 'b'}], [np.ones(3, )]])
|
||||||
|
@pytest.mark.parametrize('d', [{'a': 1}, {1: 2.0}, {'a': [1.0]}, {'a': np.ones(3, )}, {'a': {'a': 'b'}}])
|
||||||
|
@pytest.mark.parametrize('e', [Object()])
|
||||||
|
def test_change_env_kwargs(a: int, b: float, c: list, d: dict, e: Object):
|
||||||
|
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'},
|
||||||
|
a=a, b=b, c=c, d=d, e=e
|
||||||
|
)
|
||||||
|
assert a is env.a
|
||||||
|
assert b is env.b
|
||||||
|
assert c is env.c
|
||||||
|
# Due to how gym works dict kwargs need to be copied and hence can only be checked to have the same content
|
||||||
|
assert d == env.d
|
||||||
|
assert e is env.e
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
|
||||||
|
@pytest.mark.parametrize('add_time_aware_wrapper_before', [True, False])
|
||||||
|
def test_learn_sub_trajectories(env_wrap: Tuple[str, Type[RawInterfaceWrapper]], add_time_aware_wrapper_before: bool):
|
||||||
|
env_id, wrapper_class = env_wrap
|
||||||
|
env_step = TimeAwareObservation(fancy_gym.make(env_id, SEED))
|
||||||
|
wrappers = [wrapper_class]
|
||||||
|
|
||||||
|
# has time aware wrapper
|
||||||
|
if add_time_aware_wrapper_before:
|
||||||
|
wrappers += [TimeAwareObservation]
|
||||||
|
|
||||||
|
env = fancy_gym.make_bb(env_id, [wrapper_class], {'learn_sub_trajectories': True},
|
||||||
|
{'trajectory_generator_type': 'promp'},
|
||||||
|
{'controller_type': 'motor'},
|
||||||
|
{'phase_generator_type': 'linear'},
|
||||||
|
{'basis_generator_type': 'rbf'})
|
||||||
|
|
||||||
|
assert env.learn_sub_trajectories
|
||||||
|
assert env.traj_gen.learn_tau
|
||||||
|
assert env.observation_space == env_step.observation_space
|
||||||
|
|
||||||
|
env.reset()
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, r, d, info = env.step(action)
|
||||||
|
|
||||||
|
length = info['trajectory_length']
|
||||||
|
|
||||||
|
factor = 1 / env.dt
|
||||||
|
assert np.allclose(length * env.dt, np.round(factor * action[0]) / factor)
|
||||||
|
assert np.allclose(length * env.dt, np.round(factor * env.traj_gen.tau.numpy()) / factor)
|
73
test/test_controller.py
Normal file
73
test/test_controller.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from fancy_gym.black_box.factory import controller_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('ctrl_type', controller_factory.ALL_TYPES)
|
||||||
|
def test_initialization(ctrl_type: str):
|
||||||
|
controller_factory.get_controller(ctrl_type)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
def test_velocity(position: np.ndarray, velocity: np.ndarray):
|
||||||
|
ctrl = controller_factory.get_controller('velocity')
|
||||||
|
a = ctrl(position, velocity, None, None)
|
||||||
|
assert np.array_equal(a, velocity)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
def test_position(position: np.ndarray, velocity: np.ndarray):
|
||||||
|
ctrl = controller_factory.get_controller('position')
|
||||||
|
a = ctrl(position, velocity, None, None)
|
||||||
|
assert np.array_equal(a, position)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('current_position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('current_velocity', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('p_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('d_gains', [0, 1, 0.5, np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
def test_pd(position: np.ndarray, velocity: np.ndarray, current_position: np.ndarray, current_velocity: np.ndarray,
|
||||||
|
p_gains: Union[float, Tuple], d_gains: Union[float, Tuple]):
|
||||||
|
ctrl = controller_factory.get_controller('motor', p_gains=p_gains, d_gains=d_gains)
|
||||||
|
assert np.array_equal(ctrl.p_gains, p_gains)
|
||||||
|
assert np.array_equal(ctrl.d_gains, d_gains)
|
||||||
|
|
||||||
|
a = ctrl(position, velocity, current_position, current_velocity)
|
||||||
|
pd = p_gains * (position - current_position) + d_gains * (velocity - current_velocity)
|
||||||
|
assert np.array_equal(a, pd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('pos_vel', [(np.ones(3, ), np.ones(4, )),
|
||||||
|
(np.ones(4, ), np.ones(3, )),
|
||||||
|
(np.ones(4, ), np.ones(4, ))])
|
||||||
|
def test_pd_invalid_shapes(pos_vel: Tuple[np.ndarray, np.ndarray]):
|
||||||
|
position, velocity = pos_vel
|
||||||
|
ctrl = controller_factory.get_controller('motor')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ctrl(position, velocity, np.ones(3, ), np.ones(3, ))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('current_position', [np.zeros(3, ), np.ones(3, ), np.arange(0, 3)])
|
||||||
|
@pytest.mark.parametrize('gripper_pos', [0, 1, 0.5])
|
||||||
|
def test_metaworld(position: np.ndarray, current_position: np.ndarray, gripper_pos: float):
|
||||||
|
ctrl = controller_factory.get_controller('metaworld')
|
||||||
|
|
||||||
|
position_grip = np.append(position, gripper_pos)
|
||||||
|
c_position_grip = np.append(current_position, -1)
|
||||||
|
a = ctrl(position_grip, None, c_position_grip, None)
|
||||||
|
assert a[-1] == gripper_pos
|
||||||
|
assert np.array_equal(a[:-1], position - current_position)
|
||||||
|
|
||||||
|
|
||||||
|
def test_metaworld_invalid_shapes():
|
||||||
|
ctrl = controller_factory.get_controller('metaworld')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ctrl(np.ones(4, ), None, np.ones(3, ), None)
|
Loading…
Reference in New Issue
Block a user