Bug mitigation for metaworld refactored and extended
This commit is contained in:
parent
b6089c4b83
commit
7354257f8e
@ -10,6 +10,7 @@ from gymnasium.core import ActType, ObsType
|
|||||||
import fancy_gym
|
import fancy_gym
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||||
|
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
||||||
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||||
@ -128,6 +129,7 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
|
|||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
|
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
|
||||||
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
_obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
|
||||||
length = info['trajectory_length']
|
length = info['trajectory_length']
|
||||||
|
|
||||||
@ -330,6 +332,7 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
|
|||||||
for i in range(5):
|
for i in range(5):
|
||||||
if done:
|
if done:
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
|
ugly_hack_to_mitigate_metaworld_bug(env)
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
action[0] = tau
|
action[0] = tau
|
||||||
action[1] = delay
|
action[1] = delay
|
||||||
|
@ -13,6 +13,7 @@ import fancy_gym
|
|||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from fancy_gym.utils.wrappers import TimeAwareObservation
|
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||||
from fancy_gym.utils.make_env_helpers import ensure_finite_time
|
from fancy_gym.utils.make_env_helpers import ensure_finite_time
|
||||||
|
from test.utils import ugly_hack_to_mitigate_metaworld_bug
|
||||||
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
ENV_IDS = ['Reacher5d-v0', 'dmc:ball_in_cup-catch-v0', 'metaworld:reach-v2', 'Reacher-v2']
|
||||||
@ -155,21 +156,11 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
|
|||||||
print(done, (i + 1), episode_steps)
|
print(done, (i + 1), episode_steps)
|
||||||
assert (i + 1) % episode_steps == 0
|
assert (i + 1) % episode_steps == 0
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
ugly_hack_to_mitigate_metaworld_bug(env)
|
ugly_hack_to_mitigate_metaworld_bug(env) # TODO: Remove, when metaworld fixed it upstream
|
||||||
|
|
||||||
assert replanning_schedule(None, None, None, None, length)
|
assert replanning_schedule(None, None, None, None, length)
|
||||||
|
|
||||||
|
|
||||||
def ugly_hack_to_mitigate_metaworld_bug(env):
|
|
||||||
head = env
|
|
||||||
try:
|
|
||||||
for i in range(16):
|
|
||||||
head.curr_path_length = 0
|
|
||||||
head = head.env
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
|
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
|
||||||
@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
|
@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
|
||||||
@pytest.mark.parametrize('sub_segment_steps', [5, 10])
|
@pytest.mark.parametrize('sub_segment_steps', [5, 10])
|
||||||
|
@ -100,3 +100,13 @@ def verify_reward(reward):
|
|||||||
def verify_done(done):
|
def verify_done(done):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
done, bool), f"Returned {done} as done flag, expected bool."
|
done, bool), f"Returned {done} as done flag, expected bool."
|
||||||
|
|
||||||
|
|
||||||
|
def ugly_hack_to_mitigate_metaworld_bug(env):
|
||||||
|
head = env
|
||||||
|
try:
|
||||||
|
for i in range(16):
|
||||||
|
head.curr_path_length = 0
|
||||||
|
head = head.env
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user