diff --git a/test/test_black_box.py b/test/test_black_box.py index 139b1c2..76bd73e 100644 --- a/test/test_black_box.py +++ b/test/test_black_box.py @@ -4,7 +4,7 @@ from typing import Tuple, Type, Union, Optional, Callable import gymnasium as gym import numpy as np import pytest -from gymnasium import register +from gymnasium import register, make from gymnasium.core import ActType, ObsType import fancy_gym @@ -13,7 +13,7 @@ from fancy_gym.utils.wrappers import TimeAwareObservation from test.utils import ugly_hack_to_mitigate_metaworld_bug SEED = 1 -ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2'] +ENV_IDS = ['Reacher5d-v0', 'dm_control/ball_in_cup-catch-v0', 'metaworld/reach-v2', 'Reacher-v2'] WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in_cup.MPWrapper, fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper] ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) @@ -102,7 +102,7 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]] _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample()) info_keys = list(info.keys()) - env_step = fancy_gym.make(env_id, SEED) + env_step = make(env_id) env_step.reset() _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample()) info_keys_step = info.keys() @@ -161,7 +161,7 @@ def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapp {'phase_generator_type': 'exp'}, {'basis_generator_type': 'rbf'}) # check if observation space matches with the specified mask values which are true - env_step = fancy_gym.make(env_id, SEED) + env_step = make(env_id) wrapper = wrapper_class(env_step) assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape @@ -231,8 +231,9 @@ def test_learn_tau(mp_type: str, tau: float): 'learn_delay': False }, {'basis_generator_type': basis_generator_type, - }, seed=SEED) + }) + env.reset(seed=SEED) done = True for i in range(5): if done: @@ -277,8 +278,9 @@ def test_learn_delay(mp_type: str, delay: float): 'learn_delay': True }, {'basis_generator_type': basis_generator_type, - }, seed=SEED) + }) + env.reset(seed=SEED) done = True for i in range(5): if done: @@ -323,7 +325,9 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float): 'learn_delay': True }, {'basis_generator_type': basis_generator_type, - }, seed=SEED) + }) + + env.reset(seed=SEED) if env.spec.max_episode_steps * env.dt < delay + tau: return