metadata.mp_config now expected in MP_wrapper (implementing Fabian's feedback)
This commit is contained in:
parent
9fa932d2bb
commit
f6e1718c1a
@ -1,10 +1,30 @@
|
|||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import importlib
|
||||||
|
import numpy as np
|
||||||
from fancy_gym.utils.make_env_helpers import make_bb
|
from fancy_gym.utils.make_env_helpers import make_bb
|
||||||
from fancy_gym.utils.utils import nested_update
|
from fancy_gym.utils.utils import nested_update
|
||||||
|
|
||||||
from gymnasium import register as gym_register
|
from gymnasium import register as gym_register
|
||||||
from gymnasium import gym_make
|
from gymnasium import gym_make
|
||||||
|
|
||||||
import copy
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultMPWrapper(RawInterfaceWrapper):
|
||||||
|
@property
|
||||||
|
def context_mask(self):
|
||||||
|
return np.full(self.env.observation_space.shape, True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
return self.env.current_pos
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
return self.env.current_vel
|
||||||
|
|
||||||
|
|
||||||
_BB_DEFAULTS = {
|
_BB_DEFAULTS = {
|
||||||
'ProMP': {
|
'ProMP': {
|
||||||
@ -82,22 +102,26 @@ ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {mp_type: [] for mp_type in KNOWN_MP
|
|||||||
def register(
|
def register(
|
||||||
id,
|
id,
|
||||||
entry_point,
|
entry_point,
|
||||||
|
mp_wrapper=DefaultMPWrapper,
|
||||||
register_step_based=True, # TODO: Detect
|
register_step_based=True, # TODO: Detect
|
||||||
add_mp_types=KNOWN_MPS,
|
add_mp_types=KNOWN_MPS,
|
||||||
override_mp_config={},
|
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
if not callable(mp_wrapper): # mp_wrapper can be given as a String (same notation as for entry_point)
|
||||||
|
mod_name, attr_name = mp_wrapper.split(":")
|
||||||
|
mod = importlib.import_module(mod_name)
|
||||||
|
mp_wrapper = getattr(mod, attr_name)
|
||||||
if register_step_based:
|
if register_step_based:
|
||||||
gym_register(id=id, entry_point=entry_point, **kwargs)
|
gym_register(id=id, entry_point=entry_point, **kwargs)
|
||||||
register_mps(id, override_mp_config, add_mp_types)
|
register_mps(id, mp_wrapper, add_mp_types)
|
||||||
|
|
||||||
|
|
||||||
def register_mps(id, add_mp_types=KNOWN_MPS):
|
def register_mps(id, mp_wrapper, add_mp_types=KNOWN_MPS):
|
||||||
for mp_type in add_mp_types:
|
for mp_type in add_mp_types:
|
||||||
register_mp(id, mp_type)
|
register_mp(id, mp_wrapper, mp_type)
|
||||||
|
|
||||||
|
|
||||||
def register_mp(id, mp_type):
|
def register_mp(id, mp_wrapper, mp_type):
|
||||||
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
assert mp_type in KNOWN_MPS, 'Unknown mp_type'
|
||||||
assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
assert id not in ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type], f'The environment {id} is already registered for {mp_type}.'
|
||||||
parts = id.split('-')
|
parts = id.split('-')
|
||||||
@ -108,14 +132,16 @@ def register_mp(id, mp_type):
|
|||||||
entry_point=bb_env_constructor,
|
entry_point=bb_env_constructor,
|
||||||
kwargs={
|
kwargs={
|
||||||
'underlying_id': id,
|
'underlying_id': id,
|
||||||
|
'mp_wrapper': mp_wrapper,
|
||||||
'mp_type': mp_type
|
'mp_type': mp_type
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
|
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS[mp_type].append(fancy_id)
|
||||||
|
|
||||||
|
|
||||||
def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs):
|
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs):
|
||||||
underlying_env = gym_make(underlying_id, **kwargs)
|
raw_underlying_env = gym_make(underlying_id, **kwargs)
|
||||||
|
underlying_env = mp_wrapper(raw_underlying_env)
|
||||||
env_metadata = underlying_env.metadata
|
env_metadata = underlying_env.metadata
|
||||||
|
|
||||||
metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {}))
|
metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {}))
|
||||||
@ -134,8 +160,11 @@ def bb_env_constructor(underlying_id, mp_type, mp_config_override={}, **kwargs):
|
|||||||
phase_kwargs = config.pop("phase_generator_kwargs", {})
|
phase_kwargs = config.pop("phase_generator_kwargs", {})
|
||||||
basis_kwargs = config.pop("basis_generator_kwargs", {})
|
basis_kwargs = config.pop("basis_generator_kwargs", {})
|
||||||
|
|
||||||
return make_bb(underlying_env, wrappers=wrappers,
|
return make_bb(underlying_env,
|
||||||
|
wrappers=wrappers,
|
||||||
black_box_kwargs=black_box_kwargs,
|
black_box_kwargs=black_box_kwargs,
|
||||||
traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs,
|
traj_gen_kwargs=traj_gen_kwargs,
|
||||||
|
controller_kwargs=contr_kwargs,
|
||||||
phase_kwargs=phase_kwargs,
|
phase_kwargs=phase_kwargs,
|
||||||
basis_kwargs=basis_kwargs, **config)
|
basis_kwargs=basis_kwargs,
|
||||||
|
**config)
|
||||||
|
Loading…
Reference in New Issue
Block a user