Porting Metaworld Bug Mitigations
This commit is contained in:
parent
20510d8f68
commit
9ce040d110
@ -1,3 +1,4 @@
|
|||||||
|
import random
|
||||||
from typing import Iterable, Type, Union, Optional
|
from typing import Iterable, Type, Union, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -13,12 +14,10 @@ from fancy_gym.utils.env_compatibility import EnvCompatibility
|
|||||||
try:
|
try:
|
||||||
import metaworld
|
import metaworld
|
||||||
except Exception:
|
except Exception:
|
||||||
# Catch Exception rather than ImportError, because ImportError does not cover a missing mujoco-py
|
print('[FANCY GYM] Metaworld not avaible')
|
||||||
# TODO: Print info?
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class FixMetaworldHasIncorrectObsSpaceWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
def __init__(self, env: gym.Env):
|
def __init__(self, env: gym.Env):
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
@ -29,11 +28,11 @@ class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1])
|
Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1])
|
||||||
Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1])
|
Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1])
|
||||||
|
|
||||||
self.observation_space = Obs_Space_Class(low=eos.low, high=eos.high, dtype=eos.dtype)
|
self.observation_space = Obs_Space_Class(low=eos.low-np.inf, high=eos.high+np.inf, dtype=eos.dtype)
|
||||||
self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype)
|
self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype)
|
||||||
|
|
||||||
|
|
||||||
class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class FixMetaworldIncorrectResetPathLengthWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
def __init__(self, env: gym.Env):
|
def __init__(self, env: gym.Env):
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
@ -50,13 +49,12 @@ class MitigateMetaworldBug(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class MetaworldResetFix(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
class FixMetaworldIgnoresSeedOnResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
def __init__(self, env: gym.Env):
|
def __init__(self, env: gym.Env):
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
self.env.reset(**kwargs)
|
|
||||||
if 'seed' in kwargs:
|
if 'seed' in kwargs:
|
||||||
self.env.seed(kwargs['seed'])
|
self.env.seed(kwargs['seed'])
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
@ -66,32 +64,19 @@ def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str]
|
|||||||
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
||||||
|
|
||||||
_env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs)
|
env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs)
|
||||||
|
|
||||||
# setting this avoids generating the same initialization after each reset
|
# setting this avoids generating the same initialization after each reset
|
||||||
_env._freeze_rand_vec = False
|
env._freeze_rand_vec = False
|
||||||
# New argument to use global seeding
|
# New argument to use global seeding
|
||||||
_env.seeded_rand_vec = True
|
env.seeded_rand_vec = True
|
||||||
|
|
||||||
max_episode_steps = _env.max_path_length
|
|
||||||
|
|
||||||
# TODO remove this as soon as there is support for the new API
|
|
||||||
_env = EnvCompatibility(_env, render_mode)
|
|
||||||
env = _env
|
|
||||||
|
|
||||||
# gym_id = '_metaworld_compat_' + uuid.uuid4().hex + '-v0'
|
|
||||||
# gym_register(
|
|
||||||
# id=gym_id,
|
|
||||||
# entry_point=lambda: _env,
|
|
||||||
# max_episode_steps=max_episode_steps,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# TODO: enable the checker once metaworld fixes the incorrect dtype of obs and the observation space
|
|
||||||
# env = gym.make(gym_id, disable_env_checker=True)
|
|
||||||
env = MujocoMapSpacesWrapper(env)
|
|
||||||
# TODO: remove once this has been fixed upstream
|
# TODO: remove once this has been fixed upstream
|
||||||
env = MitigateMetaworldBug(env)
|
env = FixMetaworldHasIncorrectObsSpaceWrapper(env)
|
||||||
env = MetaworldResetFix(env)
|
# TODO: remove once this has been fixed upstream
|
||||||
|
# env = FixMetaworldIncorrectResetPathLengthWrapper(env)
|
||||||
|
# TODO: remove once this has been fixed upstream
|
||||||
|
env = FixMetaworldIgnoresSeedOnResetWrapper(env)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user