Bug mitigation for metaworld refactored and extended

This commit is contained in:
Dominik Moritz Roth 2023-06-18 17:47:54 +02:00
parent b6089c4b83
commit 7354257f8e
3 changed files with 15 additions and 11 deletions

View File

@ -10,6 +10,7 @@ from gymnasium.core import ActType, ObsType
import fancy_gym
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.wrappers import TimeAwareObservation
from test.utils import ugly_hack_to_mitigate_metaworld_bug
SEED = 1
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@ -128,6 +129,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
for i in range(5):
env.reset(seed=SEED)
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
length = info['trajectory_length']
@ -330,6 +332,7 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
for i in range(5):
if done:
env.reset(seed=SEED)
ugly_hack_to_mitigate_metaworld_bug(env)
action = env.action_space.sample()
action[0] = tau
action[1] = delay

View File

@ -13,6 +13,7 @@ import fancy_gym
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.wrappers import TimeAwareObservation
from fancy_gym.utils.make_env_helpers import ensure_finite_time
from test.utils import ugly_hack_to_mitigate_metaworld_bug
SEED = 1
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
@ -155,21 +156,11 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
print(done, (i + 1), episode_steps)
assert (i + 1) % episode_steps == 0
env.reset(seed=SEED)
ugly_hack_to_mitigate_metaworld_bug(env)
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
assert replanning_schedule(None, None, None, None, length)
def ugly_hack_to_mitigate_metaworld_bug(env):
head = env
try:
for i in range(16):
head.curr_path_length = 0
head = head.env
except:
pass
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
@pytest.mark.parametrize('sub_segment_steps', [5, 10])

View File

@ -100,3 +100,13 @@ def verify_reward(reward):
def verify_done(done):
assert isinstance(
done, bool), f"Returned {done} as done flag, expected bool."
def ugly_hack_to_mitigate_metaworld_bug(env):
head = env
try:
for i in range(16):
head.curr_path_length = 0
head = head.env
except:
pass