minor refactoring

This commit is contained in:
Dominik Moritz Roth 2023-07-20 10:33:39 +02:00
parent 30bafd7a4f
commit 9fa932d2bb

View File

@ -15,6 +15,7 @@ from gymnasium import make
import numpy as np import numpy as np
from gymnasium.envs.registration import register, registry from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import TimeLimit from gymnasium.wrappers import TimeLimit
from gymnasium import make as gym_make
from fancy_gym.utils.env_compatibility import EnvCompatibility from fancy_gym.utils.env_compatibility import EnvCompatibility
from fancy_gym.utils.wrappers import FlattenObservation from fancy_gym.utils.wrappers import FlattenObservation
@ -32,7 +33,7 @@ except Exception:
pass pass
def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs): def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None):
""" """
Helper function for creating a wrapped gym environment using MPs. Helper function for creating a wrapped gym environment using MPs.
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
@ -62,7 +63,7 @@ def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=
def make_bb( def make_bb(
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
fallback_max_steps: int = None, **kwargs): time_limit: int, fallback_max_steps: int = None):
""" """
This can also be used standalone for manually building a custom DMP environment. This can also be used standalone for manually building a custom DMP environment.
Args: Args:
@ -78,7 +79,7 @@ def make_bb(
Returns: DMP wrapped gym env Returns: DMP wrapped gym env
""" """
_verify_time_limit(traj_gen_kwargs.get("duration"), kwargs.get("time_limit")) _verify_time_limit(traj_gen_kwargs.get("duration"), time_limit)
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories') learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
do_replanning = black_box_kwargs.get('replanning_schedule') do_replanning = black_box_kwargs.get('replanning_schedule')
@ -93,7 +94,7 @@ def make_bb(
if isinstance(env, str): if isinstance(env, str):
env = make(env) env = make(env)
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps, **kwargs) env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps)
# BB expects a spaces.Box to be exposed, need to convert for dict-observations # BB expects a spaces.Box to be exposed, need to convert for dict-observations
if type(env.observation_space) == gym.spaces.dict.Dict: if type(env.observation_space) == gym.spaces.dict.Dict:
@ -153,6 +154,34 @@ def get_env_duration(env: gym.Env):
return duration return duration
def make(env_id: str, **kwargs):
    """
    Converts an env_id to an environment with the gym API.

    Ids starting with "metaworld" (e.g. "metaworld/env_id-v2") are built by ``make_metaworld``; the prefix and the
    single separator character following it are stripped before the call. Every other id is forwarded unchanged to
    ``gym_make``, which includes DeepMind Control Suite ids of the form "dmc/domain_name-task_name".

    Args:
        env_id: spec or env_id for gym tasks, external environments require a domain specification
        **kwargs: Additional kwargs for the constructor such as pixel observations, etc.

    Returns: Gym environment
    """
    metaworld_prefix = 'metaworld'

    if env_id.startswith(metaworld_prefix):
        # Drop the prefix plus the one separator character that follows it. Slicing (instead of str.replace) leaves
        # any later occurrence of the word inside the task name untouched.
        env = make_metaworld(env_id[len(metaworld_prefix) + 1:], **kwargs)
    else:
        # Only consult the gym registry for non-metaworld ids, so the env built above is not discarded.
        env = gym_make(env_id, **kwargs)

    spec = getattr(env, 'spec', None)
    if spec is not None and spec.max_episode_steps is not None:
        # Hack: Some envs violate the gym spec in that they don't correctly expose the maximum episode steps
        # Gymnasium disallows accessing private attributes, so we have to get creative to read the internal values
        # TODO: Remove this, when all supported envs correctly implement this themselves
        # NOTE(review): this branch runs when the spec already HAS a step limit, which looks inverted with respect
        # to the comment above (one would expect `is None`). Kept as is; confirm the intended condition.
        unwrapped = getattr(env, 'unwrapped', env)
        if hasattr(unwrapped, '_max_episode_steps'):
            spec.max_episode_steps = getattr(unwrapped, '_max_episode_steps')

    return env
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs): def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
if env_id not in metaworld.ML1.ENV_NAMES: if env_id not in metaworld.ML1.ENV_NAMES:
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.') raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')