Quickfix for mp_config merging and allow defining different base_id for
upgrades
This commit is contained in:
parent
8b3d05aaaf
commit
20b1b0ccac
@ -2,7 +2,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium import register as gym_register
|
from gymnasium import register as gym_register
|
||||||
from .registry import register
|
from .registry import register, upgrade
|
||||||
|
|
||||||
from . import classic_control, mujoco
|
from . import classic_control, mujoco
|
||||||
from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
from .classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
||||||
@ -213,12 +213,10 @@ for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
|
|||||||
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
upgrade(
|
||||||
id='fancy/BoxPushing{}Replan-v0'.format(reward_type),
|
id='fancy/BoxPushing{}Replan-v0'.format(reward_type),
|
||||||
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
base_id='fancy/BoxPushing{}-v0'.format(reward_type),
|
||||||
mp_wrapper=mujoco.box_pushing.ReplanMPWrapper,
|
mp_wrapper=mujoco.box_pushing.ReplanMPWrapper,
|
||||||
register_step_based=False,
|
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Table Tennis environments
|
# Table Tennis environments
|
||||||
|
@ -5,8 +5,10 @@ import importlib
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from collections.abc import Mapping, MutableMapping
|
||||||
|
|
||||||
from fancy_gym.utils.make_env_helpers import make_bb
|
from fancy_gym.utils.make_env_helpers import make_bb
|
||||||
from fancy_gym.utils.utils import nested_update
|
# from fancy_gym.utils.utils import nested_update
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
from gymnasium import register as gym_register
|
from gymnasium import register as gym_register
|
||||||
@ -129,33 +131,27 @@ def register(
|
|||||||
mp_wrapper = getattr(mod, attr_name)
|
mp_wrapper = getattr(mod, attr_name)
|
||||||
if register_step_based:
|
if register_step_based:
|
||||||
gym_register(id=id, entry_point=entry_point, **kwargs)
|
gym_register(id=id, entry_point=entry_point, **kwargs)
|
||||||
register_mps(id, mp_wrapper, add_mp_types, mp_config_override)
|
upgrade(id, mp_wrapper, add_mp_types, mp_config_override)
|
||||||
|
|
||||||
|
|
||||||
def upgrade(
|
def upgrade(
|
||||||
id,
|
id,
|
||||||
mp_wrapper=DefaultMPWrapper,
|
mp_wrapper=DefaultMPWrapper,
|
||||||
add_mp_types=KNOWN_MPS,
|
add_mp_types=KNOWN_MPS,
|
||||||
|
base_id=None,
|
||||||
mp_config_override={},
|
mp_config_override={},
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
register(
|
if not base_id:
|
||||||
id,
|
base_id = id
|
||||||
entry_point=None,
|
register_mps(id, base_id, mp_wrapper, add_mp_types, mp_config_override)
|
||||||
mp_wrapper=mp_wrapper,
|
|
||||||
register_step_based=False,
|
|
||||||
add_mp_types=add_mp_types,
|
|
||||||
mp_config_override=mp_config_override,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}):
|
def register_mps(id, base_id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}):
|
||||||
for mp_type in add_mp_types:
|
for mp_type in add_mp_types:
|
||||||
register_mp(id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))
|
register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
|
def register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override={}):
|
||||||
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
||||||
assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
||||||
|
|
||||||
@ -176,7 +172,7 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
|
|||||||
id=fancy_id,
|
id=fancy_id,
|
||||||
entry_point=bb_env_constructor,
|
entry_point=bb_env_constructor,
|
||||||
kwargs={
|
kwargs={
|
||||||
'underlying_id': id,
|
'underlying_id': base_id,
|
||||||
'mp_wrapper': mp_wrapper,
|
'mp_wrapper': mp_wrapper,
|
||||||
'mp_type': mp_type,
|
'mp_type': mp_type,
|
||||||
'_mp_config_override_register': mp_config_override
|
'_mp_config_override_register': mp_config_override
|
||||||
@ -190,6 +186,24 @@ def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
|
|||||||
MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id)
|
MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id)
|
||||||
MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]['all'].append(fancy_id)
|
MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]['all'].append(fancy_id)
|
||||||
|
|
||||||
|
# TODO: Apply inherit_defaults: False to appropiate places and remove this...
|
||||||
|
|
||||||
|
|
||||||
|
def nested_update(base: MutableMapping, update):
|
||||||
|
"""
|
||||||
|
Updated method for nested Mappings
|
||||||
|
Args:
|
||||||
|
base: main Mapping to be updated
|
||||||
|
update: updated values for base Mapping
|
||||||
|
|
||||||
|
"""
|
||||||
|
if any([item.endswith('_type') for item in update]):
|
||||||
|
base = update
|
||||||
|
return base
|
||||||
|
for k, v in update.items():
|
||||||
|
base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):
|
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):
|
||||||
raw_underlying_env = gym_make(underlying_id, **kwargs)
|
raw_underlying_env = gym_make(underlying_id, **kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user