Bug mitigation for metaworld refactored and extended
This commit is contained in:
parent
b6089c4b83
commit
7354257f8e
@ -10,6 +10,7 @@ from gymnasium.core import ActType, ObsType
|
||||
import fancy_gym
|
||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
||||
|
||||
SEED = 1
|
||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||
@ -128,6 +129,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
||||
|
||||
for i in range(5):
|
||||
env.reset(seed=SEED)
|
||||
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
|
||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||
length = info['trajectory_length']
|
||||
|
||||
@ -330,6 +332,7 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
|
||||
for i in range(5):
|
||||
if done:
|
||||
env.reset(seed=SEED)
|
||||
ugly_hack_to_mitigate_metaworld_bug(env)
|
||||
action = env.action_space.sample()
|
||||
action[0] = tau
|
||||
action[1] = delay
|
||||
|
@ -13,6 +13,7 @@ import fancy_gym
|
||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||
from fancy_gym.utils.make_env_helpers import ensure_finite_time
|
||||
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
||||
|
||||
SEED = 1
|
||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||
@ -155,21 +156,11 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
||||
print(done, (i + 1), episode_steps)
|
||||
assert (i + 1) % episode_steps == 0
|
||||
env.reset(seed=SEED)
|
||||
ugly_hack_to_mitigate_metaworld_bug(env)
|
||||
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
|
||||
|
||||
assert replanning_schedule(None, None, None, None, length)
|
||||
|
||||
|
||||
def ugly_hack_to_mitigate_metaworld_bug(env, max_depth=16):
    """Zero ``curr_path_length`` on ``env`` and on every layer nested below it.

    Starting at ``env``, the attribute ``curr_path_length`` is set to 0 and
    the walk then descends through ``.env`` to the next layer. This repeats
    for at most ``max_depth`` layers.

    NOTE(review): per the function name this works around a metaworld bug;
    presumably the step counter is not cleared by ``reset`` -- confirm
    against upstream and remove once it is fixed there.

    Args:
        env: Outermost environment (or wrapper) to start from.
        max_depth: Upper bound on the number of layers visited. Defaults to
            16, the limit that was previously hard-coded.

    Returns:
        None. The objects in the chain are modified in place.
    """
    head = env
    for _ in range(max_depth):
        try:
            head.curr_path_length = 0
            head = head.env
        except AttributeError:
            # No further ``.env`` to descend into (or the attribute cannot
            # be set on this object): the chain ends here. Only
            # AttributeError is treated as the stop signal so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            break
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
|
||||
@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
|
||||
@pytest.mark.parametrize('sub_segment_steps', [5, 10])
|
||||
|
@ -100,3 +100,13 @@ def verify_reward(reward):
|
||||
def verify_done(done):
    """Assert that the ``done`` flag is a built-in ``bool``.

    Args:
        done: The value to check.

    Raises:
        AssertionError: If ``done`` is not an instance of ``bool``.
    """
    is_plain_bool = isinstance(done, bool)
    assert is_plain_bool, f"Returned {done} as done flag, expected bool."
|
||||
|
||||
|
||||
def ugly_hack_to_mitigate_metaworld_bug(env, max_depth=16):
    """Zero ``curr_path_length`` on ``env`` and on every layer nested below it.

    Starting at ``env``, the attribute ``curr_path_length`` is set to 0 and
    the walk then descends through ``.env`` to the next layer. This repeats
    for at most ``max_depth`` layers.

    NOTE(review): per the function name this works around a metaworld bug;
    presumably the step counter is not cleared by ``reset`` -- confirm
    against upstream and remove once it is fixed there.

    Args:
        env: Outermost environment (or wrapper) to start from.
        max_depth: Upper bound on the number of layers visited. Defaults to
            16, the limit that was previously hard-coded.

    Returns:
        None. The objects in the chain are modified in place.
    """
    head = env
    for _ in range(max_depth):
        try:
            head.curr_path_length = 0
            head = head.env
        except AttributeError:
            # No further ``.env`` to descend into (or the attribute cannot
            # be set on this object): the chain ends here. Only
            # AttributeError is treated as the stop signal so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            break
|
||||
|
Loading…
Reference in New Issue
Block a user