From 20d0be3c8d42de4aa8220813ed8d6ff7cb6f4aa5 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 30 Jul 2023 17:56:28 +0200 Subject: [PATCH] Replicate legacy behavior in exporting lists off all mp envs --- fancy_gym/__init__.py | 14 ++++++-------- fancy_gym/envs/registry.py | 12 ++++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fancy_gym/__init__.py b/fancy_gym/__init__.py index 77f245f..4e62ecf 100644 --- a/fancy_gym/__init__.py +++ b/fancy_gym/__init__.py @@ -1,12 +1,10 @@ from fancy_gym import dmc, meta, open_ai +from fancy_gym import envs as fancy from fancy_gym.utils.make_env_helpers import make_bb from .envs.registry import register, upgrade -from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS -# Convenience function for all MP environments -from .envs import ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS -from .meta import ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS -from .open_ai import ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS +from .envs.registry import ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS, MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS -ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = { - key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] + ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] + ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] - for key, value in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.items()} +ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS['dmc'] +ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS['fancy'] +ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS = MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS['metaworld'] +ALL_GYM_MOVEMENT_PRIMITIVE_ENVIRONMENTS = MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS['gym'] diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 41419ee..8016dc2 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -105,8 +105,8 @@ _BB_DEFAULTS = { } KNOWN_MPS = list(_BB_DEFAULTS.keys()) -ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS} -FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = defaultdict(lambda: {mp_type: [] for mp_type in KNOWN_MPS}) +ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS} +MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = defaultdict(lambda: {mp_type: [] for mp_type in KNOWN_MPS}) def register( @@ -156,11 +156,11 @@ def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}): def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): assert mp_type in KNOWN_MPS, 'Unknown mp_type' - assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' + assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' parts = id.split('/') if len(parts) == 1: - ns, name = 'root', parts[0] + ns, name = 'gym', parts[0] elif len(parts) == 2: ns, name = parts[0], parts[1] else: @@ -181,8 +181,8 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): } ) - ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) - FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id) + ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) + MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id) def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):