minor refactoring

This commit is contained in:
Dominik Moritz Roth 2023-07-20 10:33:39 +02:00
parent 30bafd7a4f
commit 9fa932d2bb

View File

@ -15,6 +15,7 @@ from gymnasium import make
import numpy as np
from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import TimeLimit
from gymnasium import make as gym_make
from fancy_gym.utils.env_compatibility import EnvCompatibility
from fancy_gym.utils.wrappers import FlattenObservation
@ -32,7 +33,7 @@ except Exception:
pass
def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None):
"""
Helper function for creating a wrapped gym environment using MPs.
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
@ -62,7 +63,7 @@ def _make_wrapped_env(env: gym.Env, wrappers: Iterable[Type[gym.Wrapper]], seed=
def make_bb(
env: Union[gym.Env, str], wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping,
fallback_max_steps: int = None, **kwargs):
time_limit: int, fallback_max_steps: int = None):
"""
This can also be used standalone for manually building a custom DMP environment.
Args:
@ -78,7 +79,7 @@ def make_bb(
Returns: DMP wrapped gym env
"""
_verify_time_limit(traj_gen_kwargs.get("duration"), kwargs.get("time_limit"))
_verify_time_limit(traj_gen_kwargs.get("duration"), time_limit)
learn_sub_trajs = black_box_kwargs.get('learn_sub_trajectories')
do_replanning = black_box_kwargs.get('replanning_schedule')
@ -93,7 +94,7 @@ def make_bb(
if isinstance(env, str):
env = make(env)
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps, **kwargs)
env = _make_wrapped_env(env=env, wrappers=wrappers, fallback_max_steps=fallback_max_steps)
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
if type(env.observation_space) == gym.spaces.dict.Dict:
@ -153,6 +154,34 @@ def get_env_duration(env: gym.Env):
return duration
def make(env_id: str, **kwargs):
    """
    Converts an env_id to an environment with the gym API.

    Ids starting with "metaworld" (e.g. "metaworld/env_id-v2") are built with `make_metaworld`.
    Every other id, including DeepMind Control Suite ids of the form "dmc/domain_name-task_name",
    is passed unchanged to gymnasium's `make` (imported as `gym_make`).

    Args:
        env_id: spec or env_id for gym tasks, external environments require a domain specification
        **kwargs: Additional kwargs for the constructor such as pixel observations, etc.

    Returns: Gym environment
    """
    metaworld_prefix = 'metaworld'

    if env_id.startswith(metaworld_prefix):
        # Drop the prefix and the single separator character that follows it ("metaworld/<name>").
        # Only the leading prefix is removed, so a later "metaworld" inside the name is preserved.
        env = make_metaworld(env_id[len(metaworld_prefix) + 1:], **kwargs)
    else:
        # The previous version called gym_make unconditionally, which discarded the metaworld env
        # created above and then tried to resolve the metaworld id via gymnasium instead.
        env = gym_make(env_id, **kwargs)

    # Not every env is guaranteed to carry a spec; skip the fix-up instead of failing on `None.max_episode_steps`.
    spec = getattr(env, 'spec', None)

    # NOTE(review): the comment below talks about envs that do not expose their step limit, yet the
    # condition only fires when the spec already has one. The direction of the check is kept as it was;
    # confirm whether `is None` was intended.
    if spec is not None and spec.max_episode_steps is not None:
        # Hack: Some envs violate the gym spec in that they don't correctly expose the maximum episode steps
        # Gymnasium disallows accessing private attributes, so we have to get creative to read the internal values
        # TODO: Remove this, when all supported envs correctly implement this themselves
        unwrapped = env.unwrapped if hasattr(env, 'unwrapped') else env
        if hasattr(unwrapped, '_max_episode_steps'):
            spec.max_episode_steps = getattr(unwrapped, '_max_episode_steps')

    return env
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
if env_id not in metaworld.ML1.ENV_NAMES:
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')