diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 5176699..c3eb896 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple, Union, Callable, List, Dict, Any, Optional import copy import importlib @@ -113,14 +113,41 @@ MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = {} def register( - id, - entry_point=None, - mp_wrapper=DefaultMPWrapper, - register_step_based=True, # TODO: Detect - add_mp_types=KNOWN_MPS, - mp_config_override={}, + id: str, + entry_point: Optional[Union[Callable, str]] = None, + mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper, + register_step_based: bool = True, # TODO: Detect + add_mp_types: List[str] = KNOWN_MPS, + mp_config_override: Dict[str, Any] = {}, **kwargs ): + """ + Registers a Gymnasium environment, including Movement Primitives (MP) versions. + If you only want to register MP versions for an already registered environment, use fancy_gym.upgrade instead. + + Args: + id (str): The unique identifier for the environment. + entry_point (Optional[Union[Callable, str]]): The entry point for creating the environment. + mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment. + register_step_based (bool): Whether to also register the raw srtep-based version of the environment (default True). + add_mp_types (List[str]): List of additional MP types to register. + mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration. + **kwargs: Additional keyword arguments which are passed to the environment constructor. + + Notes: + - When `register_step_based` is True, the raw environment will also be registered to gymnasium otherwise only mp-versions will be registered. + - `entry_point` can be given as a string, allowing the same notation as gymnasium. + - If `id` already exists in the Gymnasium registry and `register_step_based` is True, + a warning message will be printed, suggesting to set `register_step_based=False` or use `fancy_gym.upgrade`. + + Example: + To register a step-based environment with Movement Primitive versions (will use default mp_wrapper): + >>> register("MyEnv-v0", MyEnvClass"my_module:MyEnvClass") + + The entry point can also be provided as a string: + >>> register("MyEnv-v0", "my_module:MyEnvClass") + + """ if register_step_based and id in gym_registry: print(f'[Info] Gymnasium env with id "{id}" already exists. You should supply register_step_based=False or use fancy_gym.upgrade if you only want to register mp versions of an existing env.') if register_step_based: @@ -135,23 +162,48 @@ def register( def upgrade( - id, - mp_wrapper=DefaultMPWrapper, - add_mp_types=KNOWN_MPS, - base_id=None, - mp_config_override={}, + id: str, + mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper, + add_mp_types: List[str] = KNOWN_MPS, + base_id: Optional[str] = None, + mp_config_override: Dict[str, Any] = {}, ): + """ + Upgrades an existing Gymnasium environment to include Movement Primitives (MP) versions. + We expect the raw step-based env to be already registered with gymnasium. Otherwise please use fancy_gym.register instead. + + Args: + id (str): The unique identifier for the environment. + mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment (default is DefaultMPWrapper). + add_mp_types (List[str]): List of additional MP types to register (default is KNOWN_MPS). + base_id (Optional[str]): The unique identifier for the environment to upgrade. Will use id if non is provided. Can be defined to allow multiple registrations of different versions for the same step-based environment. + mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration. + + Notes: + - The `id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade. You can also pick a new one, but then `base_id` needs to be provided. + - The `mp_wrapper` parameter specifies the MP wrapper to use, allowing for customization. + - `add_mp_types` can be used to specify additional MP types to register alongside the base environment. + - The `base_id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade. + - `mp_config_override` allows for customizing MP configuration if needed. + + Example: + To upgrade an existing environment with MP versions: + >>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper) + + To upgrade an existing environment with custom MP types and configuration: + >>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper, add_mp_types=["ProDMP", "DMP"], mp_config_override={"param": 42}) + """ if not base_id: base_id = id register_mps(id, base_id, mp_wrapper, add_mp_types, mp_config_override) -def register_mps(id, base_id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}): +def register_mps(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, add_mp_types: List[str] = KNOWN_MPS, mp_config_override: Dict[str, Any] = {}): for mp_type in add_mp_types: register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {})) -def register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override={}): +def register_mp(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, mp_type: List[str], mp_config_override: Dict[str, Any] = {}): assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'