Moving the ugly_mitigation_for_metaworld_bug into the metaworld env wrapper
This commit is contained in:
parent
f3ffa714cb
commit
5b99227fac
@ -24,7 +24,7 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
eos = env.observation_space
|
eos = env.observation_space
|
||||||
eas = env.observation_space
|
eas = env.action_space
|
||||||
|
|
||||||
Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1])
|
Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1])
|
||||||
Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1])
|
Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1])
|
||||||
@ -33,6 +33,23 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype)
|
self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
def __init__(self, env: gym.Env):
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
ret = self.env.reset(**kwargs)
|
||||||
|
head = self.env
|
||||||
|
try:
|
||||||
|
for i in range(16):
|
||||||
|
head.curr_path_length = 0
|
||||||
|
head = head.env
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
||||||
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
||||||
@ -60,6 +77,8 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str]
|
|||||||
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
|
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
|
||||||
env = gym.make(gym_id, disable_env_checker=True)
|
env = gym.make(gym_id, disable_env_checker=True)
|
||||||
env = MujocoMapSpacesWrapper(env)
|
env = MujocoMapSpacesWrapper(env)
|
||||||
|
# TODO remove, when this has been fixed upstream
|
||||||
|
env = MitigateMetaworldBug(env)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user