Allow overriding mp_config during register and make (also better errors for DefaultMPWrapper)
This commit is contained in:
Dominik Moritz Roth 2023-07-20 11:44:04 +02:00
parent 9d03542282
commit 17d370e2ba

View File

@ -15,14 +15,21 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
class DefaultMPWrapper(RawInterfaceWrapper):
    """MP wrapper that reads its context mask, position and velocity from the wrapped env.

    This is the wrapper ``register`` falls back to when no ``mp_wrapper`` is given.
    """

    @property
    def context_mask(self):
        """Boolean mask selecting the observation entries that form the context.

        A ``context_mask`` defined on the wrapped env takes precedence. Without
        one, an all-``True`` array shaped like the observation space is returned.
        """
        wrapped = self.env
        if not hasattr(wrapped, 'context_mask'):
            # No mask on the env: use the whole observation as the context.
            # (Write a custom MPWrapper to change this behavior)
            return np.full(wrapped.observation_space.shape, True)
        # The env already defines a context_mask, so we use that one.
        return wrapped.context_mask

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """Return ``env.current_pos``; asserts with a hint if the env does not expose it."""
        wrapped = self.env
        missing_msg = 'DefaultMPWrapper was unable to access env.current_pos. Please write a custom MPWrapper (recommended) or expose this attribute directly.'
        assert hasattr(wrapped, 'current_pos'), missing_msg
        return wrapped.current_pos

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """Return ``env.current_vel``; asserts with a hint if the env does not expose it."""
        wrapped = self.env
        missing_msg = 'DefaultMPWrapper was unable to access env.current_vel. Please write a custom MPWrapper (recommended) or expose this attribute directly.'
        assert hasattr(wrapped, 'current_vel'), missing_msg
        return wrapped.current_vel
@ -105,6 +112,7 @@ def register(
mp_wrapper=DefaultMPWrapper, mp_wrapper=DefaultMPWrapper,
register_step_based=True, # TODO: Detect register_step_based=True, # TODO: Detect
add_mp_types=KNOWN_MPS, add_mp_types=KNOWN_MPS,
mp_config_override={},
**kwargs **kwargs
): ):
if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point) if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point)
@ -113,15 +121,15 @@ def register(
mp_wrapper = getattr(mod, attr_name) mp_wrapper = getattr(mod, attr_name)
if register_step_based: if register_step_based:
gym_register(id=id, entry_point=entry_point, **kwargs) gym_register(id=id, entry_point=entry_point, **kwargs)
register_mps(id, mp_wrapper, add_mp_types) register_mps(id, mp_wrapper, add_mp_types, mp_config_override)
def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS, mp_config_override=None):
    """Register one movement-primitive variant of ``id`` per entry in ``add_mp_types``.

    Args:
        id: Id of the underlying environment; forwarded to ``register_mp``.
        mp_wrapper: Wrapper forwarded unchanged to ``register_mp``.
        add_mp_types: Iterable of MP type names to register.
        mp_config_override: Optional mapping from MP type name to the config
            override for that type. Types without an entry get an empty
            override. ``None`` (the default) means no overrides at all.
    """
    # A None sentinel replaces the former `{}` default so that no single dict
    # object is shared between calls.
    overrides = {} if mp_config_override is None else mp_config_override
    for mp_type in add_mp_types:
        register_mp(id, mp_wrapper, mp_type, overrides.get(mp_type, {}))
def register_mp(id, mp_wrapper, mp_type): def register_mp(id, mp_wrapper, mp_type, mp_config_override={}):
assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert mp_type in KNOWN_MPS, 'Unknown mp_type'
assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
parts = id.split('-') parts = id.split('-')
@ -133,13 +141,14 @@ def register_mp(id, mp_wrapper, mp_type):
kwargs={ kwargs={
'underlying_id': id, 'underlying_id': id,
'mp_wrapper': mp_wrapper, 'mp_wrapper': mp_wrapper,
'mp_type': mp_type 'mp_type': mp_type,
'_mp_config_override_register': mp_config_override
} }
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs): def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, _mp_config_override_register={}, **kwargs):
raw_underlying_env = gym_make(underlying_id, **kwargs) raw_underlying_env = gym_make(underlying_id, **kwargs)
underlying_env = mp_wrapper(raw_underlying_env) underlying_env = mp_wrapper(raw_underlying_env)
@ -150,6 +159,7 @@ def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}
config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {} config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {}
nested_update(config, active_mp_config) nested_update(config, active_mp_config)
nested_update(config, _mp_config_override_register)
nested_update(config, mp_config_override) nested_update(config, mp_config_override)
wrappers = config.pop("wrappers") wrappers = config.pop("wrappers")