Started work on new adapter & mp_config port for metaworld

This commit is contained in:
Dominik Moritz Roth 2023-07-23 12:21:34 +02:00
parent 3e586a1407
commit 99a02b8347
3 changed files with 86 additions and 60 deletions

View File

@ -1,3 +1,5 @@
from typing import Iterable, Type, Union, Optional
from copy import deepcopy
from gymnasium import register
@ -5,6 +7,10 @@ from gymnasium import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper
from . import metaworld_adapter
metaworld_adapter.register_all_ML1()
ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
# MetaWorld

View File

@ -0,0 +1,78 @@
import uuid
from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium import register as gym_register

from fancy_gym import register
from fancy_gym.utils.env_compatibility import EnvCompatibility
try:
import metaworld
except Exception:
# catch Exception as Import error does not catch missing mujoco-py
# TODO: Print info?
pass
class MujocoMapSpacesWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs):
    """Wrapper that re-declares the wrapped env's spaces as ``gymnasium.spaces`` objects.

    For the observation space and the action space, the class with the same
    name as the original space's class is looked up in ``gymnasium.spaces``
    and instantiated from the original space's ``low``, ``high`` and ``dtype``.
    Only spaces that expose these three attributes are therefore supported.

    NOTE(review): presumably needed because the wrapped environment exposes
    spaces from a different package (e.g. legacy ``gym``) -- confirm.

    Args:
        env: The environment whose spaces should be mapped.

    Raises:
        AttributeError: If ``gymnasium.spaces`` has no class with the same name
            as one of the original spaces, or a space lacks ``low``/``high``/``dtype``.
    """

    def __init__(self, env: gym.Env):
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.Wrapper.__init__(self, env)

        self.observation_space = self._to_gymnasium_space(env.observation_space)
        # The action space must be derived from env.action_space; it was
        # previously built from the observation space by mistake.
        self.action_space = self._to_gymnasium_space(env.action_space)

    @staticmethod
    def _to_gymnasium_space(space):
        """Rebuild ``space`` using the equally named class from ``gymnasium.spaces``."""
        space_cls = getattr(gym.spaces, type(space).__name__)
        return space_cls(low=space.low, high=space.high, dtype=space.dtype)
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
    """Build a gymnasium-compatible instance of a metaworld ML1 task.

    The goal-observable variant of the task is instantiated, wrapped in
    ``EnvCompatibility``, registered under a freshly generated one-off id,
    created through ``gym.make`` and finally wrapped in
    ``MujocoMapSpacesWrapper``.

    Args:
        underlying_id: Name of the task as listed in ``metaworld.ML1.ENV_NAMES``.
        seed: Seed handed to the metaworld environment constructor.
        render_mode: Render mode forwarded to ``EnvCompatibility``.
        **kwargs: Further keyword arguments for the metaworld environment constructor.

    Returns:
        The wrapped environment.

    Raises:
        ValueError: If ``underlying_id`` is not a metaworld ML1 task.
    """
    if underlying_id not in metaworld.ML1.ENV_NAMES:
        raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')

    env_cls = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"]
    raw_env = env_cls(seed=seed, **kwargs)

    # Unfreezing the random vector avoids generating the same initialization after each reset.
    raw_env._freeze_rand_vec = False
    # New argument to use global seeding.
    raw_env.seeded_rand_vec = True

    # Read the horizon from the raw env before it gets wrapped.
    horizon = raw_env.max_path_length

    # TODO remove this as soon as there is support for the new API
    compat_env = EnvCompatibility(raw_env, render_mode)

    # Register the already constructed instance under a unique throwaway id,
    # so that it can be created via gym.make with the episode step limit applied.
    throwaway_id = '_metaworld_compat_' + uuid.uuid4().hex + '-v0'
    gym_register(
        id=throwaway_id,
        entry_point=lambda: compat_env,
        max_episode_steps=horizon,
    )

    # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
    made_env = gym.make(throwaway_id, disable_env_checker=True)
    return MujocoMapSpacesWrapper(made_env)
def register_all_ML1(**kwargs):
    """Register every metaworld ML1 task with gymnasium as ``metaworld/<task name>``.

    Each task is instantiated once (goal-observable variant, ``seed=0``) to
    read its ``max_path_length``, which becomes the registered
    ``max_episode_steps``. The registered entry point is ``make_metaworld``,
    which receives the task name as ``underlying_id``.

    Args:
        **kwargs: Additional keyword arguments forwarded to the gymnasium ``register`` call.
    """
    goal_observable_envs = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
    for task_name in metaworld.ML1.ENV_NAMES:
        # A throwaway instance is only needed to look up the episode horizon.
        probe_env = goal_observable_envs[task_name + "-goal-observable"](seed=0)
        horizon = probe_env.max_path_length

        entry_point_kwargs = {'underlying_id': task_name}
        gym_register(
            id='metaworld/' + task_name,
            entry_point=make_metaworld,
            max_episode_steps=horizon,
            kwargs=entry_point_kwargs,
            **kwargs
        )

View File

@ -15,7 +15,6 @@ from gymnasium import make
import numpy as np
from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import TimeLimit
from gymnasium import make as gym_make
from fancy_gym.utils.env_compatibility import EnvCompatibility
from fancy_gym.utils.wrappers import FlattenObservation
@ -63,7 +62,7 @@ def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=
def make_bb(
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
time_limit: int, fallback_max_steps: int = None):
time_limit: int = None, fallback_max_steps: int = None, **kwargs):
"""
This can also be used standalone for manually building a custom DMP environment.
Args:
@ -92,7 +91,7 @@ def make_bb(
wrappers.insert(0, TimeAwareObservation)
if isinstance(env, str):
env = make(env)
env = make(env, **kwargs)
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps)
@ -154,63 +153,6 @@ def get_env_duration(env: gym.Env):
return duration
def make(env_id: str, **kwargs):
    """
    Converts an env_id to an environment with the gym API.
    This also works for DeepMind Control Suite environments that are wrapped using the DMCWrapper, they can be
    specified with "dmc/domain_name-task_name"
    Analogously, metaworld tasks can be created as "metaworld/env_id-v2".

    Args:
        env_id: spec or env_id for gym tasks, external environments require a domain specification
        **kwargs: Additional kwargs for the constructor such as pixel observations, etc.

    Returns: Gym environment

    """
    if env_id.startswith('metaworld'):
        # Drop the 'metaworld' prefix and the separator character following it.
        env = make_metaworld(env_id.replace('metaworld', '')[1:], **kwargs)
    else:
        # Only fall back to gymnasium for non-metaworld ids; calling it unconditionally
        # would discard the metaworld environment that was just created.
        env = gym_make(env_id, **kwargs)

    # NOTE(review): the hack below only runs when the spec already defines max_episode_steps,
    # although its description suggests it targets envs that do not expose it -- confirm the condition.
    if env.spec.max_episode_steps is not None:
        # Hack: Some envs violate the gym spec in that they don't correctly expose the maximum episode steps
        # Gymnasium disallows accessing private attributes, so we have to get creative to read the internal values
        # TODO: Remove this, when all supported envs correctly implement this themselves
        unwrapped = env.unwrapped if hasattr(env, 'unwrapped') else env

        if hasattr(unwrapped, '_max_episode_steps'):
            env.spec.max_episode_steps = unwrapped.__getattribute__('_max_episode_steps')

    return env
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
    """Build a gymnasium-compatible instance of a metaworld ML1 task.

    The goal-observable variant of the task is instantiated, wrapped in
    ``EnvCompatibility``, registered under a freshly generated one-off id and
    created through ``gym.make``.

    Args:
        env_id: Name of the task as listed in ``metaworld.ML1.ENV_NAMES``.
        seed: Seed handed to the metaworld environment constructor.
        render_mode: Render mode forwarded to ``EnvCompatibility``.
        **kwargs: Further keyword arguments for the metaworld environment constructor.

    Returns:
        The environment created by ``gym.make``.

    Raises:
        ValueError: If ``env_id`` is not a metaworld ML1 task.
    """
    if env_id not in metaworld.ML1.ENV_NAMES:
        raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')

    env_cls = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"]
    raw_env = env_cls(seed=seed, **kwargs)

    # Unfreezing the random vector avoids generating the same initialization after each reset.
    raw_env._freeze_rand_vec = False
    # New argument to use global seeding.
    raw_env.seeded_rand_vec = True

    # Read the horizon from the raw env before it gets wrapped.
    horizon = raw_env.max_path_length

    # TODO remove this as soon as there is support for the new API
    compat_env = EnvCompatibility(raw_env, render_mode)

    # Register the already constructed instance under a unique throwaway id.
    throwaway_id = uuid.uuid4().hex + '-v1'
    register(
        id=throwaway_id,
        entry_point=lambda: compat_env,
        max_episode_steps=horizon,
    )

    # TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
    return gym.make(throwaway_id, disable_env_checker=True)
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
"""
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.