diff --git a/test/test_black_box.py b/test/test_black_box.py index 53b4434..3f87375 100644 --- a/test/test_black_box.py +++ b/test/test_black_box.py @@ -10,6 +10,7 @@ from gymnasium.core import ActType, ObsType import fancy_gym from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper from fancy_gym.utils.wrappers import TimeAwareObservation +from test.utils import ugly_hack_to_mitigate_metaworld_bug SEED = 1 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2'] @@ -128,6 +129,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]): for i in range(5): env.reset(seed=SEED) + ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample()) length = info['trajectory_length'] @@ -330,6 +332,7 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float): for i in range(5): if done: env.reset(seed=SEED) + ugly_hack_to_mitigate_metaworld_bug(env) action = env.action_space.sample() action[0] = tau action[1] = delay diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py index 49c0218..e38fbd5 100644 --- a/test/test_replanning_sequencing.py +++ b/test/test_replanning_sequencing.py @@ -13,6 +13,7 @@ import fancy_gym from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper from fancy_gym.utils.wrappers import TimeAwareObservation from fancy_gym.utils.make_env_helpers import ensure_finite_time +from test.utils import ugly_hack_to_mitigate_metaworld_bug SEED = 1 ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2'] @@ -155,21 +156,11 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra print(done, (i + 1), episode_steps) assert (i + 1) % episode_steps == 0 env.reset(seed=SEED) - ugly_hack_to_mitigate_metaworld_bug(env) + ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream assert replanning_schedule(None, None, None, None, length) -def ugly_hack_to_mitigate_metaworld_bug(env): - head = env - try: - for i in range(16): - head.curr_path_length = 0 - head = head.env - except: - pass - - @pytest.mark.parametrize('mp_type', ['promp', 'prodmp']) @pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4]) @pytest.mark.parametrize('sub_segment_steps', [5, 10]) diff --git a/test/utils.py b/test/utils.py index 157f840..782b151 100644 --- a/test/utils.py +++ b/test/utils.py @@ -100,3 +100,13 @@ def verify_reward(reward): def verify_done(done): assert isinstance( done, bool), f"Returned {done} as done flag, expected bool." + + +def ugly_hack_to_mitigate_metaworld_bug(env): + head = env + try: + for i in range(16): + head.curr_path_length = 0 + head = head.env + except: + pass