minor refactoring
This commit is contained in:
parent
30bafd7a4f
commit
9fa932d2bb
@ -15,6 +15,7 @@ from gymnasium import make
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.registration import register, registry
|
from gymnasium.envs.registration import register, registry
|
||||||
from gymnasium.wrappers import TimeLimit
|
from gymnasium.wrappers import TimeLimit
|
||||||
|
from gymnasium import make as gym_make
|
||||||
|
|
||||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||||
from fancy_gym.utils.wrappers import FlattenObservation
|
from fancy_gym.utils.wrappers import FlattenObservation
|
||||||
@ -32,7 +33,7 @@ except Exception:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
|
def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None):
|
||||||
"""
|
"""
|
||||||
Helper function for creating a wrapped gym environment using MPs.
|
Helper function for creating a wrapped gym environment using MPs.
|
||||||
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
|
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
|
||||||
@ -62,7 +63,7 @@ def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=
|
|||||||
def make_bb(
|
def make_bb(
|
||||||
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
||||||
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
|
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
|
||||||
fallback_max_steps: int = None, **kwargs):
|
time_limit: int, fallback_max_steps: int = None):
|
||||||
"""
|
"""
|
||||||
This can also be used standalone for manually building a custom DMP environment.
|
This can also be used standalone for manually building a custom DMP environment.
|
||||||
Args:
|
Args:
|
||||||
@ -78,7 +79,7 @@ def make_bb(
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_verify_time_limit(traj_gen_kwargs.get("duration"), kwargs.get("time_limit"))
|
_verify_time_limit(traj_gen_kwargs.get("duration"), time_limit)
|
||||||
|
|
||||||
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
|
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
|
||||||
do_replanning = black_box_kwargs.get('replanning_schedule')
|
do_replanning = black_box_kwargs.get('replanning_schedule')
|
||||||
@ -93,7 +94,7 @@ def make_bb(
|
|||||||
if isinstance(env, str):
|
if isinstance(env, str):
|
||||||
env = make(env)
|
env = make(env)
|
||||||
|
|
||||||
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps, **kwargs)
|
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps)
|
||||||
|
|
||||||
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
|
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
|
||||||
if type(env.observation_space) == gym.spaces.dict.Dict:
|
if type(env.observation_space) == gym.spaces.dict.Dict:
|
||||||
@ -153,6 +154,34 @@ def get_env_duration(env: gym.Env):
|
|||||||
return duration
|
return duration
|
||||||
|
|
||||||
|
|
||||||
|
def make(env_id: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Converts an env_id to an environment with the gym API.
|
||||||
|
This also works for DeepMind Control Suite environments that are wrapped using the DMCWrapper, they can be
|
||||||
|
specified with "dmc/domain_name-task_name"
|
||||||
|
Analogously, metaworld tasks can be created as "metaworld/env_id-v2".
|
||||||
|
Args:
|
||||||
|
env_id: spec or env_id for gym tasks, external environments require a domain specification
|
||||||
|
**kwargs: Additional kwargs for the constructor such as pixel observations, etc.
|
||||||
|
Returns: Gym environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
if env_id.startswith('metaworld'):
|
||||||
|
env = make_metaworld(env_id.replace('metaworld', '')[1:], **kwargs)
|
||||||
|
|
||||||
|
env = gym_make(env_id, **kwargs)
|
||||||
|
|
||||||
|
if not env.spec.max_episode_steps == None:
|
||||||
|
# Hack: Some envs violate the gym spec in that they don't correctly expose the maximum episode steps
|
||||||
|
# Gymnasium disallows accessing private attributes, so we have to get creative to read the internal values
|
||||||
|
# TODO: Remove this, when all supported envs correctly implement this themselves
|
||||||
|
unwrapped = env.unwrapped if hasattr(env, 'unwrapped') else env
|
||||||
|
if hasattr(unwrapped, '_max_episode_steps'):
|
||||||
|
env.spec.max_episode_steps = unwrapped.__getattribute__('_max_episode_steps')
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
|
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
|
||||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||||
|
Loading…
Reference in New Issue
Block a user