Move mp_config out of metadata and onto MPWrappers
This commit is contained in:
		
							parent
							
								
									f6e1718c1a
								
							
						
					
					
						commit
						9d03542282
					
				@ -15,34 +15,6 @@ MAX_EPISODE_STEPS_HOLEREACHER = 200
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class HoleReacherEnv(BaseReacherDirectEnv):
 | 
					class HoleReacherEnv(BaseReacherDirectEnv):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    metadata = {
 | 
					 | 
				
			||||||
        'mp_config': {
 | 
					 | 
				
			||||||
            'ProMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'controller_type': 'velocity',
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'trajectory_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'weight_scale': 2,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'DMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'controller_type': 'velocity',
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'trajectory_generator_kwargs': {
 | 
					 | 
				
			||||||
                    # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
 | 
					 | 
				
			||||||
                    'weight_scale': 500,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'phase_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'alpha_phase': 2.5,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'ProDMP': {},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
 | 
					    def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
 | 
				
			||||||
                 hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
 | 
					                 hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
 | 
				
			||||||
                 allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
 | 
					                 allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class MPWrapper(RawInterfaceWrapper):
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mp_config = {
 | 
				
			||||||
 | 
					        'ProMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'controller_type': 'velocity',
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'trajectory_generator_kwargs': {
 | 
				
			||||||
 | 
					                'weight_scale': 2,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'DMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'controller_type': 'velocity',
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'trajectory_generator_kwargs': {
 | 
				
			||||||
 | 
					                # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
 | 
				
			||||||
 | 
					                'weight_scale': 500,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'phase_generator_kwargs': {
 | 
				
			||||||
 | 
					                'alpha_phase': 2.5,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'ProDMP': {},
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def context_mask(self):
 | 
					    def context_mask(self):
 | 
				
			||||||
        return np.hstack([
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,28 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class MPWrapper(RawInterfaceWrapper):
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mp_config = {
 | 
				
			||||||
 | 
					        'ProMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'p_gains': 0.6,
 | 
				
			||||||
 | 
					                'd_gains': 0.075,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'DMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'p_gains': 0.6,
 | 
				
			||||||
 | 
					                'd_gains': 0.075,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'trajectory_generator_kwargs': {
 | 
				
			||||||
 | 
					                'weight_scale': 50,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'phase_generator_kwargs': {
 | 
				
			||||||
 | 
					                'alpha_phase': 2,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'ProDMP': {},
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def context_mask(self):
 | 
					    def context_mask(self):
 | 
				
			||||||
        return np.hstack([
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
				
			|||||||
@ -16,32 +16,6 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
 | 
				
			|||||||
    towards the end of the trajectory.
 | 
					    towards the end of the trajectory.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    metadata = {
 | 
					 | 
				
			||||||
        'mp_config': {
 | 
					 | 
				
			||||||
            'ProMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'p_gains': 0.6,
 | 
					 | 
				
			||||||
                    'd_gains': 0.075,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'DMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'p_gains': 0.6,
 | 
					 | 
				
			||||||
                    'd_gains': 0.075,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'trajectory_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'weight_scale': 50,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'phase_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'alpha_phase': 2,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'ProDMP': {},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
 | 
					    def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
 | 
				
			||||||
                 allow_self_collision: bool = False, ):
 | 
					                 allow_self_collision: bool = False, ):
 | 
				
			||||||
        super().__init__(n_links, random_start, allow_self_collision)
 | 
					        super().__init__(n_links, random_start, allow_self_collision)
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,26 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class MPWrapper(RawInterfaceWrapper):
 | 
					class MPWrapper(RawInterfaceWrapper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mp_config = {
 | 
				
			||||||
 | 
					        'ProMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'controller_type': 'velocity',
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'DMP': {
 | 
				
			||||||
 | 
					            'controller_kwargs': {
 | 
				
			||||||
 | 
					                'controller_type': 'velocity',
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'trajectory_generator_kwargs': {
 | 
				
			||||||
 | 
					                'weight_scale': 50,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'phase_generator_kwargs': {
 | 
				
			||||||
 | 
					                'alpha_phase': 2,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        'ProDMP': {},
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def context_mask(self):
 | 
					    def context_mask(self):
 | 
				
			||||||
        return np.hstack([
 | 
					        return np.hstack([
 | 
				
			||||||
 | 
				
			|||||||
@ -12,30 +12,6 @@ from . import MPWrapper
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
 | 
					class ViaPointReacherEnv(BaseReacherDirectEnv):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    metadata = {
 | 
					 | 
				
			||||||
        'mp_config': {
 | 
					 | 
				
			||||||
            'ProMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'controller_type': 'velocity',
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'DMP': {
 | 
					 | 
				
			||||||
                'wrappers': [MPWrapper],
 | 
					 | 
				
			||||||
                'controller_kwargs': {
 | 
					 | 
				
			||||||
                    'controller_type': 'velocity',
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'trajectory_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'weight_scale': 50,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                'phase_generator_kwargs': {
 | 
					 | 
				
			||||||
                    'alpha_phase': 2,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            'ProDMP': {},
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
 | 
					    def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
 | 
				
			||||||
                 target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
 | 
					                 target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -142,14 +142,14 @@ def register_mp(id, mp_wrapper, mp_type):
 | 
				
			|||||||
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs):
 | 
					def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs):
 | 
				
			||||||
    raw_underlying_env = gym_make(underlying_id, **kwargs)
 | 
					    raw_underlying_env = gym_make(underlying_id, **kwargs)
 | 
				
			||||||
    underlying_env = mp_wrapper(raw_underlying_env)
 | 
					    underlying_env = mp_wrapper(raw_underlying_env)
 | 
				
			||||||
    env_metadata = underlying_env.metadata
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {}))
 | 
					    mp_config = underlying_env.get('mp_config', {})
 | 
				
			||||||
    global_inherit_defaults = env_metadata.get('mp_config', {}).get('inherit_defaults', True)
 | 
					    active_mp_config = copy.deepcopy(mp_config.get(mp_type, {}))
 | 
				
			||||||
    inherit_defaults = metadata_config.pop('inherit_defaults', global_inherit_defaults)
 | 
					    global_inherit_defaults = mp_config.get('inherit_defaults', True)
 | 
				
			||||||
 | 
					    inherit_defaults = active_mp_config.pop('inherit_defaults', global_inherit_defaults)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {}
 | 
					    config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {}
 | 
				
			||||||
    nested_update(config, metadata_config)
 | 
					    nested_update(config, active_mp_config)
 | 
				
			||||||
    nested_update(config, mp_config_override)
 | 
					    nested_update(config, mp_config_override)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    wrappers = config.pop("wrappers")
 | 
					    wrappers = config.pop("wrappers")
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user