diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index e7c2f09..4d09acc 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -1,10 +1,30 @@ +from typing import Tuple, Union + +import copy +import importlib +import numpy as np from fancy_gym.utils.make_env_helpers import make_bb from fancy_gym.utils.utils import nested_update from gymnasium import register as gym_register from gymnasium import gym_make -import copy +from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper + + +class DefaultMPWrapper(RawInterfaceWrapper): + @property + def context_mask(self): + return np.full(self.env.observation_space.shape, True) + + @property + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_pos + + @property + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + return self.env.current_vel + _BB_DEFAULTS = { 'ProMP': { @@ -82,22 +102,26 @@ ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MP def register( id, entry_point, + mp_wrapper=DefaultMPWrapper, register_step_based=True, # TODO: Detect add_mp_types=KNOWN_MPS, - override_mp_config={}, **kwargs ): + if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point) + mod_name, attr_name = mp_wrapper.split(":") + mod = importlib.import_module(mod_name) + mp_wrapper = getattr(mod, attr_name) if register_step_based: gym_register(id=id, entry_point=entry_point, **kwargs) - register_mps(id, override_mp_config, add_mp_types) + register_mps(id, mp_wrapper, add_mp_types) -def register_mps(id, add_mp_types=KNOWN_MPS): +def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS): for mp_type in add_mp_types: - register_mp(id, mp_type) + register_mp(id, mp_wrapper, mp_type) -def register_mp(id, mp_type): +def register_mp(id, mp_wrapper, mp_type): assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' parts = id.split('-') @@ -108,14 +132,16 @@ def register_mp(id, mp_type): entry_point=bb_env_constructor, kwargs={ 'underlying_id': id, + 'mp_wrapper': mp_wrapper, 'mp_type': mp_type } ) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) -def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs): - underlying_env = gym_make(underlying_id, **kwargs) +def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs): + raw_underlying_env = gym_make(underlying_id, **kwargs) + underlying_env = mp_wrapper(raw_underlying_env) env_metadata = underlying_env.metadata metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {})) @@ -134,8 +160,11 @@ def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs): phase_kwargs = config.pop("phase_generator_kwargs", {}) basis_kwargs = config.pop("basis_generator_kwargs", {}) - return make_bb(underlying_env, wrappers=wrappers, + return make_bb(underlying_env, + wrappers=wrappers, black_box_kwargs=black_box_kwargs, - traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, + traj_gen_kwargs=traj_gen_kwargs, + controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs, - basis_kwargs=basis_kwargs, **config) + basis_kwargs=basis_kwargs, + **config)