From ae1033a18c4697ecc2c54d1f2c156837b11f2a48 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 29 Jul 2023 11:26:48 +0200 Subject: [PATCH] Remember mp-envs for each ns seperately (replicate legacy functionality) --- fancy_gym/envs/registry.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 3696fed..9cf2135 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -3,14 +3,15 @@ from typing import Tuple, Union import copy import importlib import numpy as np +from collections import defaultdict + from fancy_gym.utils.make_env_helpers import make_bb from fancy_gym.utils.utils import nested_update +from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper from gymnasium import register as gym_register from gymnasium import make as gym_make -from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper - class DefaultMPWrapper(RawInterfaceWrapper): @property @@ -104,6 +105,7 @@ _BB_DEFAULTS = { KNOWN_MPS = list(_BB_DEFAULTS.keys()) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS} +FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = defaultdict(lambda: {mp_type: [] for mp_type in KNOWN_MPS}) def register( @@ -152,9 +154,19 @@ def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}): def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' + + parts = id.split('/') + if len(parts) == 1: + ns, name = 'root', parts[0] + elif len(parts) == 2: + ns, name = parts[0], parts[1] + else: + raise ValueError('env id can not contain multiple "/".') + parts = id.split('-') assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.' fancy_id = '-'.join(parts[:-1]+[mp_type, parts[-1]]) + gym_register( id=fancy_id, entry_point=bb_env_constructor, @@ -165,7 +177,9 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): '_mp_config_override_register': mp_config_override } ) + ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) + FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id) def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):