From fb8f81afeaceaa5653a675931592b69879ede976 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 30 Jul 2023 18:26:45 +0200 Subject: [PATCH] Don't use defaultdicts for MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS (is ugly when exporting) --- fancy_gym/envs/registry.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 8016dc2..83ca1ac 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -105,8 +105,9 @@ _BB_DEFAULTS = { } KNOWN_MPS = list(_BB_DEFAULTS.keys()) -ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS} -MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = defaultdict(lambda: {mp_type: [] for mp_type in KNOWN_MPS}) +_KNOWN_MPS_PLUS_ALL = KNOWN_MPS + ['all'] +ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in _KNOWN_MPS_PLUS_ALL} +MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = {} def register( @@ -182,7 +183,11 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): ) ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) + ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['all'].append(fancy_id) + if ns not in MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS: + MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns] = {mp_type: [] for mp_type in _KNOWN_MPS_PLUS_ALL} MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id) + MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]['all'].append(fancy_id) def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):