Remember mp-envs for each ns seperately (replicate legacy functionality)

This commit is contained in:
Dominik Moritz Roth 2023-07-29 11:26:48 +02:00
parent 2fc44667c6
commit ae1033a18c

View File

@ -3,14 +3,15 @@ from typing import Tuple, Union
import copy import copy
import importlib import importlib
import numpy as np import numpy as np
from collections import defaultdict
from fancy_gym.utils.make_env_helpers import make_bb from fancy_gym.utils.make_env_helpers import make_bb
from fancy_gym.utils.utils import nested_update from fancy_gym.utils.utils import nested_update
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from gymnasium import register as gym_register from gymnasium import register as gym_register
from gymnasium import make as gym_make from gymnasium import make as gym_make
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
class DefaultMPWrapper(RawInterfaceWrapper): class DefaultMPWrapper(RawInterfaceWrapper):
@property @property
@ -104,6 +105,7 @@ _BB_DEFAULTS = {
KNOWN_MPS = list(_BB_DEFAULTS.keys()) KNOWN_MPS = list(_BB_DEFAULTS.keys())
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS} ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MPS}
FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = defaultdict(lambda: {mp_type: [] for mp_type in KNOWN_MPS})
def register( def register(
@ -152,9 +154,19 @@ def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}):
def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert mp_type in KNOWN_MPS, 'Unknown mp_type'
assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
parts = id.split('/')
if len(parts) == 1:
ns, name = 'root', parts[0]
elif len(parts) == 2:
ns, name = parts[0], parts[1]
else:
raise ValueError('env id can not contain multiple "/".')
parts = id.split('-') parts = id.split('-')
assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.' assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.'
fancy_id = '-'.join(parts[:-1]+[mp_type, parts[-1]]) fancy_id = '-'.join(parts[:-1]+[mp_type, parts[-1]])
gym_register( gym_register(
id=fancy_id, id=fancy_id,
entry_point=bb_env_constructor, entry_point=bb_env_constructor,
@ -165,7 +177,9 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
'_mp_config_override_register': mp_config_override '_mp_config_override_register': mp_config_override
} }
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id)
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs): def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):