From 9ce040d1101f2e3c3bcfb5aa5fe5a8c518f51ae9 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 18 Sep 2023 18:40:10 +0200 Subject: [PATCH] Porting Metaworld Bug Mitigations --- fancy_gym/meta/metaworld_adapter.py | 43 ++++++++++------------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/fancy_gym/meta/metaworld_adapter.py b/fancy_gym/meta/metaworld_adapter.py index ed2b5b6..71e8ef0 100644 --- a/fancy_gym/meta/metaworld_adapter.py +++ b/fancy_gym/meta/metaworld_adapter.py @@ -1,3 +1,4 @@ +import random from typing import Iterable, Type, Union, Optional import numpy as np @@ -13,12 +14,10 @@ from fancy_gym.utils.env_compatibility import EnvCompatibility try: import metaworld except Exception: - # catch Exception as Import error does not catch missing mujoco-py - # TODO: Print info? - pass + print('[FANCY GYM] Metaworld not avaible') -class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): +class FixMetaworldHasIncorrectObsSpaceWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): def __init__(self, env: gym.Env): gym.utils.RecordConstructorArgs.__init__(self) gym.Wrapper.__init__(self, env) @@ -29,11 +28,11 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1]) Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1]) - self.observation_space = Obs_Space_Class(low=eos.low, high=eos.high, dtype=eos.dtype) + self.observation_space = Obs_Space_Class(low=eos.low-np.inf, high=eos.high+np.inf, dtype=eos.dtype) self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype) -class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs): +class FixMetaworldIncorrectResetPathLengthWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): def __init__(self, env: gym.Env): gym.utils.RecordConstructorArgs.__init__(self) gym.Wrapper.__init__(self, env) @@ -50,13 +49,12 @@ class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs): return ret -class MetaworldResetFix(gym.Wrapper, gym.utils.RecordConstructorArgs): +class FixMetaworldIgnoresSeedOnResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): def __init__(self, env: gym.Env): gym.utils.RecordConstructorArgs.__init__(self) gym.Wrapper.__init__(self, env) def reset(self, **kwargs): - self.env.reset(**kwargs) if 'seed' in kwargs: self.env.seed(kwargs['seed']) return self.env.reset(**kwargs) @@ -66,32 +64,19 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] if underlying_id not in metaworld.ML1.ENV_NAMES: raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.') - _env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs) + env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs) # setting this avoids generating the same initialization after each reset - _env._freeze_rand_vec = False + env._freeze_rand_vec = False # New argument to use global seeding - _env.seeded_rand_vec = True + env.seeded_rand_vec = True - max_episode_steps = _env.max_path_length - - # TODO remove this as soon as there is support for the new API - _env = EnvCompatibility(_env, render_mode) - env = _env - - # gym_id = '_metaworld_compat_' + uuid.uuid4().hex + '-v0' - # gym_register( - # id=gym_id, - # entry_point=lambda: _env, - # max_episode_steps=max_episode_steps, - # ) - - # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld - # env = gym.make(gym_id, disable_env_checker=True) - env = MujocoMapSpacesWrapper(env) # TODO remove, when this has been fixed upstream - env = MitigateMetaworldBug(env) - env = MetaworldResetFix(env) + env = FixMetaworldHasIncorrectObsSpaceWrapper(env) + # TODO remove, when this has been fixed upstream + # env = FixMetaworldIncorrectResetPathLengthWrapper(env) + # TODO remove, when this has been fixed upstream + env = FixMetaworldIgnoresSeedOnResetWrapper(env) return env