diff --git a/test/test_fancy_registry.py b/test/test_fancy_registry.py index 961d3b6..aad076b 100644 --- a/test/test_fancy_registry.py +++ b/test/test_fancy_registry.py @@ -9,8 +9,7 @@ from gymnasium.core import ActType, ObsType import fancy_gym from fancy_gym import register -ENV_IDS = ['fancy/Reacher5d-v0', 'dm_control/ball_in_cup-catch-v0', 'metaworld/reach-v2', 'Reacher-v2'] -KNOWN_NS = ['dm_controll', 'fancy', 'metaworld', 'gym'] +KNOWN_NS = ['dm_control', 'fancy', 'metaworld', 'gym'] class Object(object): @@ -41,33 +40,39 @@ class ToyEnv(gym.Env): @pytest.fixture(scope="session", autouse=True) def setup(): register( - id=f'toy2-v0', + id=f'dummy/toy2-v0', entry_point='test.test_black_box:ToyEnv', max_episode_steps=50, ) -@pytest.mark.parametrize('env_id', ENV_IDS) +@pytest.mark.parametrize('env_id', ['dummy/toy2-v0']) @pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP']) def test_make_mp(env_id: str, mp_type: str): - parts = env_id.split('-') - assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.' - fancy_id = '-'.join(parts[:-1]+[mp_type, parts[-1]]) + parts = env_id.split('/') + if len(parts) == 1: + ns, name = 'gym', parts[0] + elif len(parts) == 2: + ns, name = parts[0], parts[1] + else: + raise ValueError('env id can not contain multiple "/".') + + fancy_id = f'{ns}_{mp_type}/{name}' make(fancy_id) def test_make_raw_toy(): - make('toy2-v0') + make('dummy/toy2-v0') @pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP']) def test_make_mp_toy(mp_type: str): - fancy_id = '-'.join(['toy2', mp_type, 'v0']) + fancy_id = f'dummy_{mp_type}/toy2-v0' make(fancy_id) @pytest.mark.parametrize('ns', KNOWN_NS) def test_ns_nonempty(ns): - assert len(fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]), f'The namespace {ns} is empty even though, it should not be...' + assert len(fancy_gym.MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]), f'The namespace {ns} is empty even though, it should not be...'