diff --git a/fancy_gym/meta/metaworld_adapter.py b/fancy_gym/meta/metaworld_adapter.py index b0dda4d..6f5859f 100644 --- a/fancy_gym/meta/metaworld_adapter.py +++ b/fancy_gym/meta/metaworld_adapter.py @@ -24,7 +24,7 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): gym.Wrapper.__init__(self, env) eos = env.observation_space - eas = env.observation_space + eas = env.action_space Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1]) Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1]) @@ -33,6 +33,23 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype) +class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs): + def __init__(self, env: gym.Env): + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + + def reset(self, **kwargs): + ret = self.env.reset(**kwargs) + head = self.env + try: + for i in range(16): + head.curr_path_length = 0 + head = head.env + except: + pass + return ret + + def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs): if underlying_id not in metaworld.ML1.ENV_NAMES: raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.') @@ -60,6 +77,8 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld env = gym.make(gym_id, disable_env_checker=True) env = MujocoMapSpacesWrapper(env) + # TODO remove, when this has been fixed upstream + env = MitigateMetaworldBug(env) return env