fancy_gym/test/test_fancy_registry.py

79 lines
2.2 KiB
Python
Raw Normal View History

2023-07-23 12:21:18 +02:00
from typing import Tuple, Type, Union, Optional, Callable
import gymnasium as gym
import numpy as np
import pytest
from gymnasium import make
from gymnasium.core import ActType, ObsType
import fancy_gym
from fancy_gym import register
# Namespaces that test_ns_nonempty looks up in
# fancy_gym.MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS and expects to be non-empty.
KNOWN_NS = ['dm_control', 'fancy', 'metaworld', 'gym']
2023-07-23 12:21:18 +02:00
class Object:
    """Empty placeholder class, used as an arbitrary non-primitive value."""
class ToyEnv(gym.Env):
    """Minimal environment with a one-dimensional observation and action space.

    Every observation is the constant ``[-1.]``, every step yields a reward of
    ``1`` and the environment itself never terminates or truncates an episode.
    The constructor arguments are only stored on the instance.
    """

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    dt = 0.02

    def __init__(self, a: int = 0, b: float = 0.0, c: Optional[list] = None,
                 d: Optional[dict] = None, e: Optional[Object] = None):
        """Store the given arguments as attributes of the same name.

        ``c``, ``d`` and ``e`` default to ``None`` and are then replaced by a
        fresh ``[]``, ``{}`` and ``Object()`` respectively, so that separate
        instances never share one mutable default object.
        """
        self.a, self.b = a, b
        self.c = [] if c is None else c
        self.d = {} if d is None else d
        self.e = Object() if e is None else e

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Tuple[ObsType, dict]:
        """Reset the environment and return ``(observation, info)``.

        ``seed`` is forwarded to ``gym.Env.reset``; ``return_info`` and
        ``options`` are accepted but not used.
        """
        # Let the base class handle seeding of its random number generator.
        super().reset(seed=seed)
        # Use the dtype declared by observation_space for the observation.
        obs = np.array([-1], dtype=self.observation_space.dtype)
        return obs, {}

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Ignore ``action`` and return a constant transition.

        Returns ``(observation, reward, terminated, truncated, info)``.
        """
        obs = np.array([-1], dtype=self.observation_space.dtype)
        reward, terminated, truncated, info = 1, False, False, {}
        return obs, reward, terminated, truncated, info

    def render(self, mode="human"):
        """Do nothing; this environment has no visual output."""
        pass
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register the ``dummy/toy2-v0`` environment for the test session.

    The fixture is session-scoped and autouse, so it runs without being
    requested explicitly by a test.
    """
    register(
        # Plain string literal: the id contains no placeholders.
        id='dummy/toy2-v0',
        # The entry point string names ToyEnv in the test.test_black_box module.
        entry_point='test.test_black_box:ToyEnv',
        max_episode_steps=50,
    )
@pytest.mark.parametrize('env_id', ['dummy/toy2-v0'])
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
def test_make_mp(env_id: str, mp_type: str):
    """Make the ``mp_type`` variant of ``env_id``.

    The id passed to ``make`` is ``<ns>_<mp_type>/<name>``. An ``env_id``
    without a ``/`` is placed in the ``gym`` namespace.

    Raises:
        ValueError: if ``env_id`` contains more than one ``/``.
    """
    pieces = env_id.split('/')
    if len(pieces) > 2:
        raise ValueError('env id can not contain multiple "/".')

    # str.split always yields at least one piece; the last one is the name and
    # an optional leading piece is the namespace.
    *prefix, name = pieces
    ns = prefix[0] if prefix else 'gym'

    make(f'{ns}_{mp_type}/{name}')
def test_make_raw_toy():
    """Make the toy environment directly by its registered id."""
    raw_id = 'dummy/toy2-v0'
    make(raw_id)
2023-07-23 12:21:18 +02:00
2023-07-23 13:13:02 +02:00
@pytest.mark.parametrize('mp_type', ['ProMP', 'DMP', 'ProDMP'])
def test_make_mp_toy(mp_type: str):
    """Make the ``dummy_<mp_type>/toy2-v0`` environment for each ``mp_type``."""
    make('dummy_{}/toy2-v0'.format(mp_type))
@pytest.mark.parametrize('ns', KNOWN_NS)
def test_ns_nonempty(ns):
    """Check that the entry for ``ns`` in the per-namespace mapping is non-empty."""
    registered = fancy_gym.MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]
    message = f'The namespace {ns} is empty even though, it should not be...'
    assert len(registered) > 0, message