Started work on new adapter & mp_config port for metaworld
This commit is contained in:
parent
3e586a1407
commit
99a02b8347
@ -1,3 +1,5 @@
|
||||
from typing import Iterable, Type, Union, Optional
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium import register
|
||||
@ -5,6 +7,10 @@ from gymnasium import register
|
||||
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
|
||||
object_change_mp_wrapper
|
||||
|
||||
from . import metaworld_adapter
|
||||
|
||||
metaworld_adapter.register_all_ML1()
|
||||
|
||||
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
|
||||
|
||||
# MetaWorld
|
||||
|
78
fancy_gym/meta/metaworld_adapter.py
Normal file
78
fancy_gym/meta/metaworld_adapter.py
Normal file
@ -0,0 +1,78 @@
|
||||
import numpy as np
|
||||
from gymnasium import register as gym_register
|
||||
from fancy_gym import register
|
||||
|
||||
import uuid
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||
|
||||
try:
|
||||
import metaworld
|
||||
except Exception:
|
||||
# catch Exception as Import error does not catch missing mujoco-py
|
||||
# TODO: Print info?
|
||||
pass
|
||||
|
||||
|
||||
class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||
def __init__(self, env: gym.Env):
|
||||
gym.utils.RecordConstructorArgs.__init__(self)
|
||||
gym.Wrapper.__init__(self, env)
|
||||
|
||||
eos = env.observation_space
|
||||
eas = env.observation_space
|
||||
|
||||
Obs_Space_Class = getattr(gym.spaces, str(eos.__class__).split("'")[1].split('.')[-1])
|
||||
Act_Space_Class = getattr(gym.spaces, str(eas.__class__).split("'")[1].split('.')[-1])
|
||||
|
||||
self.observation_space = Obs_Space_Class(low=eos.low, high=eos.high, dtype=eos.dtype)
|
||||
self.action_space = Act_Space_Class(low=eas.low, high=eas.high, dtype=eas.dtype)
|
||||
|
||||
|
||||
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
||||
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
||||
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
||||
|
||||
_env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs)
|
||||
|
||||
# setting this avoids generating the same initialization after each reset
|
||||
_env._freeze_rand_vec = False
|
||||
# New argument to use global seeding
|
||||
_env.seeded_rand_vec = True
|
||||
|
||||
max_episode_steps = _env.max_path_length
|
||||
|
||||
# TODO remove this as soon as there is support for the new API
|
||||
_env = EnvCompatibility(_env, render_mode)
|
||||
|
||||
gym_id = '_metaworld_compat_' + uuid.uuid4().hex + '-v0'
|
||||
|
||||
gym_register(
|
||||
id=gym_id,
|
||||
entry_point=lambda: _env,
|
||||
max_episode_steps=max_episode_steps,
|
||||
)
|
||||
|
||||
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
|
||||
env = gym.make(gym_id, disable_env_checker=True)
|
||||
env = MujocoMapSpacesWrapper(env)
|
||||
return env
|
||||
|
||||
|
||||
def register_all_ML1(**kwargs):
|
||||
for env_id in metaworld.ML1.ENV_NAMES:
|
||||
_env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=0)
|
||||
max_episode_steps = _env.max_path_length
|
||||
|
||||
gym_register(
|
||||
id='metaworld/'+env_id,
|
||||
entry_point=make_metaworld,
|
||||
max_episode_steps=max_episode_steps,
|
||||
kwargs={
|
||||
'underlying_id': env_id
|
||||
},
|
||||
**kwargs
|
||||
)
|
@ -15,7 +15,6 @@ from gymnasium import make
|
||||
import numpy as np
|
||||
from gymnasium.envs.registration import register, registry
|
||||
from gymnasium.wrappers import TimeLimit
|
||||
from gymnasium import make as gym_make
|
||||
|
||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||
from fancy_gym.utils.wrappers import FlattenObservation
|
||||
@ -63,7 +62,7 @@ def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=
|
||||
def make_bb(
|
||||
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
||||
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
|
||||
time_limit: int, fallback_max_steps: int = None):
|
||||
time_limit: int = None, fallback_max_steps: int = None, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom DMP environment.
|
||||
Args:
|
||||
@ -92,7 +91,7 @@ def make_bb(
|
||||
wrappers.insert(0, TimeAwareObservation)
|
||||
|
||||
if isinstance(env, str):
|
||||
env = make(env)
|
||||
env = make(env, **kwargs)
|
||||
|
||||
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps)
|
||||
|
||||
@ -154,63 +153,6 @@ def get_env_duration(env: gym.Env):
|
||||
return duration
|
||||
|
||||
|
||||
def make(env_id: str, **kwargs):
|
||||
"""
|
||||
Converts an env_id to an environment with the gym API.
|
||||
This also works for DeepMind Control Suite environments that are wrapped using the DMCWrapper, they can be
|
||||
specified with "dmc/domain_name-task_name"
|
||||
Analogously, metaworld tasks can be created as "metaworld/env_id-v2".
|
||||
Args:
|
||||
env_id: spec or env_id for gym tasks, external environments require a domain specification
|
||||
**kwargs: Additional kwargs for the constructor such as pixel observations, etc.
|
||||
Returns: Gym environment
|
||||
"""
|
||||
|
||||
if env_id.startswith('metaworld'):
|
||||
env = make_metaworld(env_id.replace('metaworld', '')[1:], **kwargs)
|
||||
|
||||
env = gym_make(env_id, **kwargs)
|
||||
|
||||
if not env.spec.max_episode_steps == None:
|
||||
# Hack: Some envs violate the gym spec in that they don't correctly expose the maximum episode steps
|
||||
# Gymnasium disallows accessing private attributes, so we have to get creative to read the internal values
|
||||
# TODO: Remove this, when all supported envs correctly implement this themselves
|
||||
unwrapped = env.unwrapped if hasattr(env, 'unwrapped') else env
|
||||
if hasattr(unwrapped, '_max_episode_steps'):
|
||||
env.spec.max_episode_steps = unwrapped.__getattribute__('_max_episode_steps')
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
|
||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||
|
||||
_env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=seed, **kwargs)
|
||||
|
||||
# setting this avoids generating the same initialization after each reset
|
||||
_env._freeze_rand_vec = False
|
||||
# New argument to use global seeding
|
||||
_env.seeded_rand_vec = True
|
||||
|
||||
max_episode_steps = _env.max_path_length
|
||||
|
||||
# TODO remove this as soon as there is support for the new API
|
||||
_env = EnvCompatibility(_env, render_mode)
|
||||
|
||||
gym_id = uuid.uuid4().hex + '-v1'
|
||||
|
||||
register(
|
||||
id=gym_id,
|
||||
entry_point=lambda: _env,
|
||||
max_episode_steps=max_episode_steps,
|
||||
)
|
||||
|
||||
# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
|
||||
env = gym.make(gym_id, disable_env_checker=True)
|
||||
return env
|
||||
|
||||
|
||||
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
||||
"""
|
||||
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
|
||||
|
Loading…
Reference in New Issue
Block a user