diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 0172eaa..9e37fcc 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -15,14 +15,21 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class DefaultMPWrapper(RawInterfaceWrapper): @property def context_mask(self): + # If the env already defines a context_mask, we will use that + if hasattr(self.env, 'context_mask'): + return self.env.context_mask + + # Otherwise we will use the whole observation as the context. (Write a custom MPWrapper to change this behavior) return np.full(self.env.observation_space.shape, True) @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + assert hasattr(self.env, 'current_pos'), 'DefaultMPWrapper was unable to access env.current_pos. Please write a custom MPWrapper (recommended) or expose this attribute directly.' return self.env.current_pos @property def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + assert hasattr(self.env, 'current_vel'), 'DefaultMPWrapper was unable to access env.current_vel. Please write a custom MPWrapper (recommended) or expose this attribute directly.' return self.env.current_vel @@ -105,6 +112,7 @@ def register( mp_wrapper=DefaultMPWrapper, register_step_based=True, # TODO: Detect add_mp_types=KNOWN_MPS, + mp_config_override={}, **kwargs ): if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point) @@ -113,15 +121,15 @@ def register( mp_wrapper = getattr(mod, attr_name) if register_step_based: gym_register(id=id, entry_point=entry_point, **kwargs) - register_mps(id, mp_wrapper, add_mp_types) + register_mps(id, mp_wrapper, add_mp_types, mp_config_override) -def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS): +def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override={}): for mp_type in add_mp_types: - register_mp(id, mp_wrapper, mp_type) + register_mp(id, mp_wrapper, mp_type, mp_config_override.get(mp_type, {})) -def register_mp(id, mp_wrapper, mp_type): +def register_mp(id, mp_wrapper, mp_type, mp_config_override={}): assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' parts = id.split('-') @@ -133,13 +141,14 @@ def register_mp(id, mp_wrapper, mp_type): kwargs={ 'underlying_id': id, 'mp_wrapper': mp_wrapper, - 'mp_type': mp_type + 'mp_type': mp_type, + '_mp_config_override_register': mp_config_override } ) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) -def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs): +def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs): raw_underlying_env = gym_make(underlying_id, **kwargs) underlying_env = mp_wrapper(raw_underlying_env) @@ -150,6 +159,7 @@ def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={} config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {} nested_update(config, active_mp_config) + nested_update(config, _mp_config_override_register) nested_update(config, mp_config_override) wrappers = config.pop("wrappers")