+from typing import Tuple, Union, Callable, List, Dict, Any, Optional
+
+import copy
+import importlib
+import numpy as np
+from collections import defaultdict
+
+from collections.abc import Mapping, MutableMapping
+
+from fancy_gym.utils.make_env_helpers import make_bb
+from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
+
+from gymnasium import register as gym_register
+from gymnasium import make as gym_make
+from gymnasium.envs.registration import registry as gym_registry
+
+
+class DefaultMPWrapper(RawInterfaceWrapper):
+ @property
+ def context_mask(self):
+ """
+ Returns boolean mask of the same shape as the observation space.
+ It determines whether the observation is returned for the contextual case or not.
+ This effectively allows to filter unwanted or unnecessary observations from the full step-based case.
+ E.g. Velocities starting at 0 are only changing after the first action. Given we only receive the
+ context/part of the first observation, the velocities are not necessary in the observation for the task.
+ Returns:
+ bool array representing the indices of the observations
+ """
+ # If the env already defines a context_mask, we will use that
+ if hasattr(self.env, 'context_mask'):
+ return self.env.context_mask
+
+ # Otherwise we will use the whole observation as the context. (Write a custom MPWrapper to change this behavior)
+ return np.full(self.env.observation_space.shape, True)
+
+ @property
+ def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+ """
+ Returns the current position of the action/control dimension.
+ The dimensionality has to match the action/control dimension.
+ This is not required when exclusively using velocity control,
+ it should, however, be implemented regardless.
+ E.g. The joint positions that are directly or indirectly controlled by the action.
+ """
+ assert hasattr(self.env, 'current_pos'), 'DefaultMPWrapper was unable to access env.current_pos. Please write a custom MPWrapper (recommended) or expose this attribute directly.'
+ return self.env.current_pos
+
+ @property
+ def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+ """
+ Returns the current velocity of the action/control dimension.
+ The dimensionality has to match the action/control dimension.
+ This is not required when exclusively using position control,
+ it should, however, be implemented regardless.
+ E.g. The joint velocities that are directly or indirectly controlled by the action.
+ """
+ assert hasattr(self.env, 'current_vel'), 'DefaultMPWrapper was unable to access env.current_vel. Please write a custom MPWrapper (recommended) or expose this attribute directly.'
+ return self.env.current_vel
+
+
+_BB_DEFAULTS = {
+ 'ProMP': {
+ 'wrappers': [],
+ 'trajectory_generator_kwargs': {
+ 'trajectory_generator_type': 'promp'
+ },
+ 'phase_generator_kwargs': {
+ 'phase_generator_type': 'linear'
+ },
+ 'controller_kwargs': {
+ 'controller_type': 'motor',
+ 'p_gains': 1.0,
+ 'd_gains': 0.1,
+ },
+ 'basis_generator_kwargs': {
+ 'basis_generator_type': 'zero_rbf',
+ 'num_basis': 5,
+ 'num_basis_zero_start': 1,
+ 'basis_bandwidth_factor': 3.0,
+ },
+ 'black_box_kwargs': {
+ }
+ },
+ 'DMP': {
+ 'wrappers': [],
+ 'trajectory_generator_kwargs': {
+ 'trajectory_generator_type': 'dmp'
+ },
+ 'phase_generator_kwargs': {
+ 'phase_generator_type': 'exp'
+ },
+ 'controller_kwargs': {
+ 'controller_type': 'motor',
+ 'p_gains': 1.0,
+ 'd_gains': 0.1,
+ },
+ 'basis_generator_kwargs': {
+ 'basis_generator_type': 'rbf',
+ 'num_basis': 5
+ },
+ 'black_box_kwargs': {
+ }
+ },
+ 'ProDMP': {
+ 'wrappers': [],
+ 'trajectory_generator_kwargs': {
+ 'trajectory_generator_type': 'prodmp',
+ 'duration': 2.0,
+ 'weights_scale': 1.0,
+ },
+ 'phase_generator_kwargs': {
+ 'phase_generator_type': 'exp',
+ 'tau': 1.5,
+ },
+ 'controller_kwargs': {
+ 'controller_type': 'motor',
+ 'p_gains': 1.0,
+ 'd_gains': 0.1,
+ },
+ 'basis_generator_kwargs': {
+ 'basis_generator_type': 'prodmp',
+ 'alpha': 10,
+ 'num_basis': 5,
+ },
+ 'black_box_kwargs': {
+ }
+ }
+}
+
+KNOWN_MPS = list(_BB_DEFAULTS.keys())
+_KNOWN_MPS_PLUS_ALL = KNOWN_MPS + ['all']
+ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in _KNOWN_MPS_PLUS_ALL}
+MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS = {}
+
+
+[docs]def register(
+
id: str,
+
entry_point: Optional[Union[Callable, str]] = None,
+
mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper,
+
register_step_based: bool = True, # TODO: Detect
+
add_mp_types: List[str] = KNOWN_MPS,
+
mp_config_override: Dict[str, Any] = {},
+
**kwargs
+
):
+
"""
+
Registers a Gymnasium environment, including Movement Primitives (MP) versions.
+
If you only want to register MP versions for an already registered environment, use fancy_gym.upgrade instead.
+
+
Args:
+
id (str): The unique identifier for the environment.
+
entry_point (Optional[Union[Callable, str]]): The entry point for creating the environment.
+
mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment.
+
register_step_based (bool): Whether to also register the raw srtep-based version of the environment (default True).
+
add_mp_types (List[str]): List of additional MP types to register.
+
mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration.
+
**kwargs: Additional keyword arguments which are passed to the environment constructor.
+
+
Notes:
+
- When `register_step_based` is True, the raw environment will also be registered to gymnasium otherwise only mp-versions will be registered.
+
- `entry_point` can be given as a string, allowing the same notation as gymnasium.
+
- If `id` already exists in the Gymnasium registry and `register_step_based` is True,
+
a warning message will be printed, suggesting to set `register_step_based=False` or use `fancy_gym.upgrade`.
+
+
Example:
+
To register a step-based environment with Movement Primitive versions (will use default mp_wrapper):
+
>>> register("MyEnv-v0", MyEnvClass"my_module:MyEnvClass")
+
+
The entry point can also be provided as a string:
+
>>> register("MyEnv-v0", "my_module:MyEnvClass")
+
+
"""
+
if register_step_based and id in gym_registry:
+
print(f'[Info] Gymnasium env with id "{id}" already exists. You should supply register_step_based=False or use fancy_gym.upgrade if you only want to register mp versions of an existing env.')
+
if register_step_based:
+
assert entry_point != None, 'You need to provide an entry-point, when registering step-based.'
+
if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point)
+
mod_name, attr_name = mp_wrapper.split(':')
+
mod = importlib.import_module(mod_name)
+
mp_wrapper = getattr(mod, attr_name)
+
if register_step_based:
+
gym_register(id=id, entry_point=entry_point, **kwargs)
+
upgrade(id, mp_wrapper, add_mp_types, mp_config_override)
+
+
+[docs]def upgrade(
+
id: str,
+
mp_wrapper: RawInterfaceWrapper = DefaultMPWrapper,
+
add_mp_types: List[str] = KNOWN_MPS,
+
base_id: Optional[str] = None,
+
mp_config_override: Dict[str, Any] = {},
+
):
+
"""
+
Upgrades an existing Gymnasium environment to include Movement Primitives (MP) versions.
+
We expect the raw step-based env to be already registered with gymnasium. Otherwise please use fancy_gym.register instead.
+
+
Args:
+
id (str): The unique identifier for the environment.
+
mp_wrapper (RawInterfaceWrapper): The MP wrapper for the environment (default is DefaultMPWrapper).
+
add_mp_types (List[str]): List of additional MP types to register (default is KNOWN_MPS).
+
base_id (Optional[str]): The unique identifier for the environment to upgrade. Will use id if non is provided. Can be defined to allow multiple registrations of different versions for the same step-based environment.
+
mp_config_override (Dict[str, Any]): Dictionary for overriding MP configuration.
+
+
Notes:
+
- The `id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade. You can also pick a new one, but then `base_id` needs to be provided.
+
- The `mp_wrapper` parameter specifies the MP wrapper to use, allowing for customization.
+
- `add_mp_types` can be used to specify additional MP types to register alongside the base environment.
+
- The `base_id` parameter should match the ID of the existing Gymnasium environment you wish to upgrade.
+
- `mp_config_override` allows for customizing MP configuration if needed.
+
+
Example:
+
To upgrade an existing environment with MP versions:
+
>>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper)
+
+
To upgrade an existing environment with custom MP types and configuration:
+
>>> upgrade("MyEnv-v0", mp_wrapper=CustomMPWrapper, add_mp_types=["ProDMP", "DMP"], mp_config_override={"param": 42})
+
"""
+
if not base_id:
+
base_id = id
+
register_mps(id, base_id, mp_wrapper, add_mp_types, mp_config_override)
+
+
+def register_mps(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, add_mp_types: List[str] = KNOWN_MPS, mp_config_override: Dict[str, Any] = {}):
+ for mp_type in add_mp_types:
+ register_mp(id, base_id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {}))
+
+
+def register_mp(id: str, base_id: str, mp_wrapper: RawInterfaceWrapper, mp_type: List[str], mp_config_override: Dict[str, Any] = {}):
+ assert mp_type in KNOWN_MPS, 'Unknown mp_type'
+ assert id not in ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
+
+ parts = id.split('/')
+ if len(parts) == 1:
+ ns, name = 'gym', parts[0]
+ elif len(parts) == 2:
+ ns, name = parts[0], parts[1]
+ else:
+ raise ValueError('env id can not contain multiple "/".')
+
+ parts = name.split('-')
+ assert len(parts) >= 2 and parts[-1].startswith('v'), 'Malformed env id, must end in -v{int}.'
+
+ fancy_id = f'{ns}_{mp_type}/{name}'
+
+ gym_register(
+ id=fancy_id,
+ entry_point=bb_env_constructor,
+ kwargs={
+ 'underlying_id': base_id,
+ 'mp_wrapper': mp_wrapper,
+ 'mp_type': mp_type,
+ '_mp_config_override_register': mp_config_override
+ }
+ )
+
+ ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
+ ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['all'].append(fancy_id)
+ if ns not in MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS:
+ MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns] = {mp_type: [] for mp_type in _KNOWN_MPS_PLUS_ALL}
+ MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns][mp_type].append(fancy_id)
+ MOVEMENT_PRIMITIVE_ENVIRONMENTS_FOR_NS[ns]['all'].append(fancy_id)
+
+
+def nested_update(base: MutableMapping, update):
+ """
+ Updated method for nested Mappings
+ Args:
+ base: main Mapping to be updated
+ update: updated values for base Mapping
+
+ """
+ if any([item.endswith('_type') for item in update]):
+ base = update
+ return base
+ for k, v in update.items():
+ base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v
+ return base
+
+
+def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):
+ raw_underlying_env = gym_make(underlying_id, **kwargs)
+ underlying_env = mp_wrapper(raw_underlying_env)
+
+ mp_config = getattr(underlying_env, 'mp_config') if hasattr(underlying_env, 'mp_config') else {}
+ active_mp_config = copy.deepcopy(mp_config.get(mp_type, {}))
+ global_inherit_defaults = mp_config.get('inherit_defaults', True)
+ inherit_defaults = active_mp_config.pop('inherit_defaults', global_inherit_defaults)
+
+ config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {}
+ nested_update(config, active_mp_config)
+ nested_update(config, _mp_config_override_register)
+ nested_update(config, mp_config_override)
+
+ wrappers = config.pop('wrappers')
+
+ traj_gen_kwargs = config.pop('trajectory_generator_kwargs', {})
+ black_box_kwargs = config.pop('black_box_kwargs', {})
+ contr_kwargs = config.pop('controller_kwargs', {})
+ phase_kwargs = config.pop('phase_generator_kwargs', {})
+ basis_kwargs = config.pop('basis_generator_kwargs', {})
+
+ return make_bb(underlying_env,
+ wrappers=wrappers,
+ black_box_kwargs=black_box_kwargs,
+ traj_gen_kwargs=traj_gen_kwargs,
+ controller_kwargs=contr_kwargs,
+ phase_kwargs=phase_kwargs,
+ basis_kwargs=basis_kwargs,
+ **config)
+