diff --git a/fancy_gym/__init__.py b/fancy_gym/__init__.py index 32308fa..77f245f 100644 --- a/fancy_gym/__init__.py +++ b/fancy_gym/__init__.py @@ -1,6 +1,6 @@ from fancy_gym import dmc, meta, open_ai from fancy_gym.utils.make_env_helpers import make_bb -from .envs.registry import register +from .envs.registry import register, upgrade from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS # Convenience function for all MP environments from .envs import ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 7580e47..3696fed 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -126,6 +126,24 @@ def register( register_mps(id, mp_wrapper, add_mp_types, mp_config_override) +def upgrade( + id, + mp_wrapper=DefaultMPWrapper, + add_mp_types=KNOWN_MPS, + mp_config_override={}, + **kwargs +): + register( + id, + entry_point=None, + mp_wrapper=mp_wrapper, + register_step_based=False, + add_mp_types=add_mp_types, + mp_config_override={}, + **kwargs + ) + + def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}): for mp_type in add_mp_types: register_mp(id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))