Better doumentation of fancy registry fucntions (register & upgrade)
This commit is contained in:
parent
d6ecc0dc67
commit
56c1c65d09
@ -1,4 +1,4 @@
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union, Callable, List, Dict, Any, Optional
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
@ -113,14 +113,41 @@ MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = {}
|
|||||||
|
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
id,
|
id: str,
|
||||||
entry_point=None,
|
entry_point: Optional[Union[Callable, str]] = None,
|
||||||
mp_wrapper=DefaultMPWrapper,
|
mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper,
|
||||||
register_step_based=True, # TODO: Detect
|
register_step_based: bool = True, # TODO: Detect
|
||||||
add_mp_types=KNOWN_MPS,
|
add_mp_types: List[str] = KNOWN_MPS,
|
||||||
mp_config_override={},
|
mp_config_override: Dict[str, Any] = {},
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Registers a Gymnasium environment, including Movement Primitives (MP) versions.
|
||||||
|
If you only want to register MP versions for an already registered environment, use fancy_gym.upgrade instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The unique identifier for the environment.
|
||||||
|
entry_point (Optional[Union[Callable, str]]): The entry point for creating the environment.
|
||||||
|
mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment.
|
||||||
|
register_step_based (bool): Whether to also register the raw srtep-based version of the environment (default True).
|
||||||
|
add_mp_types (List[str]): List of additional MP types to register.
|
||||||
|
mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration.
|
||||||
|
**kwargs: Additional keyword arguments which are passed to the environment constructor.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- When `register_step_based` is True, the raw environment will also be registered to gymnasium otherwise only mp-versions will be registered.
|
||||||
|
- `entry_point` can be given as a string, allowing the same notation as gymnasium.
|
||||||
|
- If `id` already exists in the Gymnasium registry and `register_step_based` is True,
|
||||||
|
a warning message will be printed, suggesting to set `register_step_based=False` or use `fancy_gym.upgrade`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To register a step-based environment with Movement Primitive versions (will use default mp_wrapper):
|
||||||
|
>>> register("MyEnv-v0", MyEnvClass"my_module:MyEnvClass")
|
||||||
|
|
||||||
|
The entry point can also be provided as a string:
|
||||||
|
>>> register("MyEnv-v0", "my_module:MyEnvClass")
|
||||||
|
|
||||||
|
"""
|
||||||
if register_step_based and id in gym_registry:
|
if register_step_based and id in gym_registry:
|
||||||
print(f'[Info] Gymnasium env with id "{id}" already exists. You should supply register_step_based=False or use fancy_gym.upgrade if you only want to register mp versions of an existing env.')
|
print(f'[Info] Gymnasium env with id "{id}" already exists. You should supply register_step_based=False or use fancy_gym.upgrade if you only want to register mp versions of an existing env.')
|
||||||
if register_step_based:
|
if register_step_based:
|
||||||
@ -135,23 +162,48 @@ def register(
|
|||||||
|
|
||||||
|
|
||||||
def upgrade(
|
def upgrade(
|
||||||
id,
|
id: str,
|
||||||
mp_wrapper=DefaultMPWrapper,
|
mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper,
|
||||||
add_mp_types=KNOWN_MPS,
|
add_mp_types: List[str] = KNOWN_MPS,
|
||||||
base_id=None,
|
base_id: Optional[str] = None,
|
||||||
mp_config_override={},
|
mp_config_override: Dict[str, Any] = {},
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Upgrades an existing Gymnasium environment to include Movement Primitives (MP) versions.
|
||||||
|
We expect the raw step-based env to be already registered with gymnasium. Otherwise please use fancy_gym.register instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id (str): The unique identifier for the environment.
|
||||||
|
mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment (default is DefaultMPWrapper).
|
||||||
|
add_mp_types (List[str]): List of additional MP types to register (default is KNOWN_MPS).
|
||||||
|
base_id (Optional[str]): The unique identifier for the environment to upgrade. Will use id if non is provided. Can be defined to allow multiple registrations of different versions for the same step-based environment.
|
||||||
|
mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The `id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade. You can also pick a new one, but then `base_id` needs to be provided.
|
||||||
|
- The `mp_wrapper` parameter specifies the MP wrapper to use, allowing for customization.
|
||||||
|
- `add_mp_types` can be used to specify additional MP types to register alongside the base environment.
|
||||||
|
- The `base_id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade.
|
||||||
|
- `mp_config_override` allows for customizing MP configuration if needed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
To upgrade an existing environment with MP versions:
|
||||||
|
>>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper)
|
||||||
|
|
||||||
|
To upgrade an existing environment with custom MP types and configuration:
|
||||||
|
>>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper, add_mp_types=["ProDMP", "DMP"], mp_config_override={"param": 42})
|
||||||
|
"""
|
||||||
if not base_id:
|
if not base_id:
|
||||||
base_id = id
|
base_id = id
|
||||||
register_mps(id, base_id, mp_wrapper, add_mp_types, mp_config_override)
|
register_mps(id, base_id, mp_wrapper, add_mp_types, mp_config_override)
|
||||||
|
|
||||||
|
|
||||||
def register_mps(id, base_id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}):
|
def register_mps(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, add_mp_types: List[str] = KNOWN_MPS, mp_config_override: Dict[str, Any] = {}):
|
||||||
for mp_type in add_mp_types:
|
for mp_type in add_mp_types:
|
||||||
register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))
|
register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override={}):
|
def register_mp(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, mp_type: List[str], mp_config_override: Dict[str, Any] = {}):
|
||||||
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
||||||
assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user