From 9d03542282f2e4267aabd796e731e028cae83180 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Thu, 20 Jul 2023 10:56:30 +0200 Subject: [PATCH] Move mp_config out of metadata and onto MPWrappers --- .../hole_reacher/hole_reacher.py | 28 ------------------- .../hole_reacher/mp_wrapper.py | 24 ++++++++++++++++ .../simple_reacher/mp_wrapper.py | 22 +++++++++++++++ .../simple_reacher/simple_reacher.py | 26 ----------------- .../viapoint_reacher/mp_wrapper.py | 20 +++++++++++++ .../viapoint_reacher/viapoint_reacher.py | 24 ---------------- fancy_gym/envs/registry.py | 10 +++---- 7 files changed, 71 insertions(+), 83 deletions(-) diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index 1fdf464..c9e0a61 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -15,34 +15,6 @@ MAX_EPISODE_STEPS_HOLEREACHER = 200 class HoleReacherEnv(BaseReacherDirectEnv): - metadata = { - 'mp_config': { - 'ProMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'controller_type': 'velocity', - }, - 'trajectory_generator_kwargs': { - 'weight_scale': 2, - }, - }, - 'DMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'controller_type': 'velocity', - }, - 'trajectory_generator_kwargs': { - # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check - 'weight_scale': 500, - }, - 'phase_generator_kwargs': { - 'alpha_phase': 2.5, - }, - }, - 'ProDMP': {}, - } - } - def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None, hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False, allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"): diff --git a/fancy_gym/envs/classic_control/hole_reacher/mp_wrapper.py b/fancy_gym/envs/classic_control/hole_reacher/mp_wrapper.py index d160b5c..c8e6dcc 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/mp_wrapper.py +++ b/fancy_gym/envs/classic_control/hole_reacher/mp_wrapper.py @@ -7,6 +7,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 2, + }, + }, + 'DMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + }, + 'trajectory_generator_kwargs': { + # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check + 'weight_scale': 500, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2.5, + }, + }, + 'ProDMP': {}, + } + @property def context_mask(self): return np.hstack([ diff --git a/fancy_gym/envs/classic_control/simple_reacher/mp_wrapper.py b/fancy_gym/envs/classic_control/simple_reacher/mp_wrapper.py index 6d1fda1..2ee3cd1 100644 --- a/fancy_gym/envs/classic_control/simple_reacher/mp_wrapper.py +++ b/fancy_gym/envs/classic_control/simple_reacher/mp_wrapper.py @@ -7,6 +7,28 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'p_gains': 0.6, + 'd_gains': 0.075, + }, + }, + 'DMP': { + 'controller_kwargs': { + 'p_gains': 0.6, + 'd_gains': 0.075, + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 50, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + }, + 'ProDMP': {}, + } + @property def context_mask(self): return np.hstack([ diff --git a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py index bb72848..5c63cf8 100644 --- a/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py +++ b/fancy_gym/envs/classic_control/simple_reacher/simple_reacher.py @@ -16,32 +16,6 @@ class SimpleReacherEnv(BaseReacherTorqueEnv): towards the end of the trajectory. """ - metadata = { - 'mp_config': { - 'ProMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'p_gains': 0.6, - 'd_gains': 0.075, - }, - }, - 'DMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'p_gains': 0.6, - 'd_gains': 0.075, - }, - 'trajectory_generator_kwargs': { - 'weight_scale': 50, - }, - 'phase_generator_kwargs': { - 'alpha_phase': 2, - }, - }, - 'ProDMP': {}, - } - } - def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True, allow_self_collision: bool = False, ): super().__init__(n_links, random_start, allow_self_collision) diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/mp_wrapper.py b/fancy_gym/envs/classic_control/viapoint_reacher/mp_wrapper.py index 47da749..c07b651 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/mp_wrapper.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/mp_wrapper.py @@ -7,6 +7,26 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper class MPWrapper(RawInterfaceWrapper): + mp_config = { + 'ProMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + }, + }, + 'DMP': { + 'controller_kwargs': { + 'controller_type': 'velocity', + }, + 'trajectory_generator_kwargs': { + 'weight_scale': 50, + }, + 'phase_generator_kwargs': { + 'alpha_phase': 2, + }, + }, + 'ProDMP': {}, + } + @property def context_mask(self): return np.hstack([ diff --git a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py index d0d04fb..febccc7 100644 --- a/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py +++ b/fancy_gym/envs/classic_control/viapoint_reacher/viapoint_reacher.py @@ -12,30 +12,6 @@ from . import MPWrapper class ViaPointReacherEnv(BaseReacherDirectEnv): - metadata = { - 'mp_config': { - 'ProMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'controller_type': 'velocity', - }, - }, - 'DMP': { - 'wrappers': [MPWrapper], - 'controller_kwargs': { - 'controller_type': 'velocity', - }, - 'trajectory_generator_kwargs': { - 'weight_scale': 50, - }, - 'phase_generator_kwargs': { - 'alpha_phase': 2, - }, - }, - 'ProDMP': {}, - } - } - def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None, target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): diff --git a/fancy_gym/envs/registry.py b/fancy_gym/envs/registry.py index 4d09acc..0172eaa 100644 --- a/fancy_gym/envs/registry.py +++ b/fancy_gym/envs/registry.py @@ -142,14 +142,14 @@ def register_mp(id, mp_wrapper, mp_type): def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs): raw_underlying_env = gym_make(underlying_id, **kwargs) underlying_env = mp_wrapper(raw_underlying_env) - env_metadata = underlying_env.metadata - metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {})) - global_inherit_defaults = env_metadata.get('mp_config', {}).get('inherit_defaults', True) - inherit_defaults = metadata_config.pop('inherit_defaults', global_inherit_defaults) + mp_config = underlying_env.get('mp_config', {}) + active_mp_config = copy.deepcopy(mp_config.get(mp_type, {})) + global_inherit_defaults = mp_config.get('inherit_defaults', True) + inherit_defaults = active_mp_config.pop('inherit_defaults', global_inherit_defaults) config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {} - nested_update(config, metadata_config) + nested_update(config, active_mp_config) nested_update(config, mp_config_override) wrappers = config.pop("wrappers")