ugly_hack_to_mitigate_metaworld_bug

This commit is contained in:
Dominik Moritz Roth 2023-06-18 15:52:17 +02:00
parent f8ad65b790
commit b6089c4b83
2 changed files with 13 additions and 2 deletions

View File

@ -17,7 +17,7 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
MAX_STEPS_FALLBACK = 500
MAX_STEPS_FALLBACK = 100
class Object(object):

View File

@ -20,7 +20,7 @@ WRAPPERS = [fancy_gym.envs.mujoco.reacher.MPWrapper, fancy_gym.dmc.suite.ball_in
fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper, fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper]
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
MAX_STEPS_FALLBACK = 100
MAX_STEPS_FALLBACK = 50
class ToyEnv(gym.Env):
@ -155,10 +155,21 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
print(done, (i + 1), episode_steps)
assert (i + 1) % episode_steps == 0
env.reset(seed=SEED)
ugly_hack_to_mitigate_metaworld_bug(env)
assert replanning_schedule(None, None, None, None, length)
def ugly_hack_to_mitigate_metaworld_bug(env):
head = env
try:
for i in range(16):
head.curr_path_length = 0
head = head.env
except:
pass
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
@pytest.mark.parametrize('sub_segment_steps', [5, 10])