metadata.mp_config now expected in MP_wrapper (implementing Fabian's feedback)

This commit is contained in:
Dominik Moritz Roth 2023-07-20 10:34:38 +02:00
parent 9fa932d2bb
commit f6e1718c1a

View File

@ -1,10 +1,30 @@
from typing import Tuple, Union
import copy
import importlib
import numpy as np
from fancy_gym.utils.make_env_helpers import make_bb from fancy_gym.utils.make_env_helpers import make_bb
from fancy_gym.utils.utils import nested_update from fancy_gym.utils.utils import nested_update
from gymnasium import register as gym_register from gymnasium import register as gym_register
from gymnasium import gym_make from gymnasium import gym_make
import copy from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
class DefaultMPWrapper(RawInterfaceWrapper):
@property
def context_mask(self):
return np.full(self.env.observation_space.shape, True)
@property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_pos
@property
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.current_vel
_BB_DEFAULTS = { _BB_DEFAULTS = {
'ProMP': { 'ProMP': {
@ -82,22 +102,26 @@ ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MP
def register( def register(
id, id,
entry_point, entry_point,
mp_wrapper=DefaultMPWrapper,
register_step_based=True, # TODO: Detect register_step_based=True, # TODO: Detect
add_mp_types=KNOWN_MPS, add_mp_types=KNOWN_MPS,
override_mp_config={},
**kwargs **kwargs
): ):
if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point)
mod_name, attr_name = mp_wrapper.split(":")
mod = importlib.import_module(mod_name)
mp_wrapper = getattr(mod, attr_name)
if register_step_based: if register_step_based:
gym_register(id=id, entry_point=entry_point, **kwargs) gym_register(id=id, entry_point=entry_point, **kwargs)
register_mps(id, override_mp_config, add_mp_types) register_mps(id, mp_wrapper, add_mp_types)
def register_mps(id, add_mp_types=KNOWN_MPS): def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS):
for mp_type in add_mp_types: for mp_type in add_mp_types:
register_mp(id, mp_type) register_mp(id, mp_wrapper, mp_type)
def register_mp(id, mp_type): def register_mp(id, mp_wrapper, mp_type):
assert mp_type in KNOWN_MPS, 'Unknown mp_type' assert mp_type in KNOWN_MPS, 'Unknown mp_type'
assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.' assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
parts = id.split('-') parts = id.split('-')
@ -108,14 +132,16 @@ def register_mp(id, mp_type):
entry_point=bb_env_constructor, entry_point=bb_env_constructor,
kwargs={ kwargs={
'underlying_id': id, 'underlying_id': id,
'mp_wrapper': mp_wrapper,
'mp_type': mp_type 'mp_type': mp_type
} }
) )
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id) ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs): def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs):
underlying_env = gym_make(underlying_id, **kwargs) raw_underlying_env = gym_make(underlying_id, **kwargs)
underlying_env = mp_wrapper(raw_underlying_env)
env_metadata = underlying_env.metadata env_metadata = underlying_env.metadata
metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {})) metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {}))
@ -134,8 +160,11 @@ def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs):
phase_kwargs = config.pop("phase_generator_kwargs", {}) phase_kwargs = config.pop("phase_generator_kwargs", {})
basis_kwargs = config.pop("basis_generator_kwargs", {}) basis_kwargs = config.pop("basis_generator_kwargs", {})
return make_bb(underlying_env, wrappers=wrappers, return make_bb(underlying_env,
wrappers=wrappers,
black_box_kwargs=black_box_kwargs, black_box_kwargs=black_box_kwargs,
traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, traj_gen_kwargs=traj_gen_kwargs,
controller_kwargs=contr_kwargs,
phase_kwargs=phase_kwargs, phase_kwargs=phase_kwargs,
basis_kwargs=basis_kwargs, **config) basis_kwargs=basis_kwargs,
**config)