Fix: Multiple issues in test/test_fancy_registry.py
This commit is contained in:
parent
14d545acee
commit
1fb5368cc2
@ -9,8 +9,7 @@ from gymnasium.core import ActType, ObsType
|
|||||||
import fancy_gym
|
import fancy_gym
|
||||||
from fancy_gym import register
|
from fancy_gym import register
|
||||||
|
|
||||||
ENV_IDS = ['fancy/Reacher5d-v0', 'dm_control/ball_in_cup-catch-v0', 'metaworld/reach-v2', 'Reacher-v2']
|
KNOWN_NS = ['dm_control', 'fancy', 'metaworld', 'gym']
|
||||||
KNOWN_NS = ['dm_controll', 'fancy', 'metaworld', 'gym']
|
|
||||||
|
|
||||||
|
|
||||||
class Object(object):
|
class Object(object):
|
||||||
@ -41,33 +40,39 @@ class ToyEnv(gym.Env):
|
|||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def setup():
|
def setup():
|
||||||
register(
|
register(
|
||||||
id=f'toy2-v0',
|
id=f'dummy/toy2-v0',
|
||||||
entry_point='test.test_black_box:ToyEnv',
|
entry_point='test.test_black_box:ToyEnv',
|
||||||
max_episode_steps=50,
|
max_episode_steps=50,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('env_id', ENV_IDS)
|
@pytest.mark.parametrize('env_id', ['dummy/toy2-v0'])
|
||||||
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
|
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
|
||||||
def test_make_mp(env_id: str, mp_type: str):
|
def test_make_mp(env_id: str, mp_type: str):
|
||||||
parts = env_id.split('-')
|
parts = env_id.split('/')
|
||||||
assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.'
|
if len(parts) == 1:
|
||||||
fancy_id = '-'.join(parts[:-1]+[mp_type, parts[-1]])
|
ns, name = 'gym', parts[0]
|
||||||
|
elif len(parts) == 2:
|
||||||
|
ns, name = parts[0], parts[1]
|
||||||
|
else:
|
||||||
|
raise ValueError('env id can not contain multiple "/".')
|
||||||
|
|
||||||
|
fancy_id = f'{ns}_{mp_type}/{name}'
|
||||||
|
|
||||||
make(fancy_id)
|
make(fancy_id)
|
||||||
|
|
||||||
|
|
||||||
def test_make_raw_toy():
|
def test_make_raw_toy():
|
||||||
make('toy2-v0')
|
make('dummy/toy2-v0')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
|
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
|
||||||
def test_make_mp_toy(mp_type: str):
|
def test_make_mp_toy(mp_type: str):
|
||||||
fancy_id = '-'.join(['toy2', mp_type, 'v0'])
|
fancy_id = f'dummy_{mp_type}/toy2-v0'
|
||||||
|
|
||||||
make(fancy_id)
|
make(fancy_id)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('ns', KNOWN_NS)
|
@pytest.mark.parametrize('ns', KNOWN_NS)
|
||||||
def test_ns_nonempty(ns):
|
def test_ns_nonempty(ns):
|
||||||
assert len(fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]), f'The namespace {ns} is empty even though, it should not be...'
|
assert len(fancy_gym.MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]), f'The namespace {ns} is empty even though, it should not be...'
|
||||||
|
Loading…
Reference in New Issue
Block a user