Move mp_config out of metadata and onto MPWrappers
This commit is contained in:
		
							parent
							
								
									f6e1718c1a
								
							
						
					
					
						commit
						9d03542282
					
				@ -15,34 +15,6 @@ MAX_EPISODE_STEPS_HOLEREACHER = 200
 | 
			
		||||
 | 
			
		||||
class HoleReacherEnv(BaseReacherDirectEnv):
 | 
			
		||||
 | 
			
		||||
    metadata = {
 | 
			
		||||
        'mp_config': {
 | 
			
		||||
            'ProMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'controller_type': 'velocity',
 | 
			
		||||
                },
 | 
			
		||||
                'trajectory_generator_kwargs': {
 | 
			
		||||
                    'weight_scale': 2,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'DMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'controller_type': 'velocity',
 | 
			
		||||
                },
 | 
			
		||||
                'trajectory_generator_kwargs': {
 | 
			
		||||
                    # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
 | 
			
		||||
                    'weight_scale': 500,
 | 
			
		||||
                },
 | 
			
		||||
                'phase_generator_kwargs': {
 | 
			
		||||
                    'alpha_phase': 2.5,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'ProDMP': {},
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
 | 
			
		||||
                 hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
 | 
			
		||||
                 allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,30 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
			
		||||
 | 
			
		||||
class MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
 | 
			
		||||
    mp_config = {
 | 
			
		||||
        'ProMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'controller_type': 'velocity',
 | 
			
		||||
            },
 | 
			
		||||
            'trajectory_generator_kwargs': {
 | 
			
		||||
                'weight_scale': 2,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'DMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'controller_type': 'velocity',
 | 
			
		||||
            },
 | 
			
		||||
            'trajectory_generator_kwargs': {
 | 
			
		||||
                # TODO: Before it was weight scale 50 and goal scale 0.1. We now only have weight scale and thus set it to 500. Check
 | 
			
		||||
                'weight_scale': 500,
 | 
			
		||||
            },
 | 
			
		||||
            'phase_generator_kwargs': {
 | 
			
		||||
                'alpha_phase': 2.5,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'ProDMP': {},
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def context_mask(self):
 | 
			
		||||
        return np.hstack([
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,28 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
			
		||||
 | 
			
		||||
class MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
 | 
			
		||||
    mp_config = {
 | 
			
		||||
        'ProMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'p_gains': 0.6,
 | 
			
		||||
                'd_gains': 0.075,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'DMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'p_gains': 0.6,
 | 
			
		||||
                'd_gains': 0.075,
 | 
			
		||||
            },
 | 
			
		||||
            'trajectory_generator_kwargs': {
 | 
			
		||||
                'weight_scale': 50,
 | 
			
		||||
            },
 | 
			
		||||
            'phase_generator_kwargs': {
 | 
			
		||||
                'alpha_phase': 2,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'ProDMP': {},
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def context_mask(self):
 | 
			
		||||
        return np.hstack([
 | 
			
		||||
 | 
			
		||||
@ -16,32 +16,6 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
 | 
			
		||||
    towards the end of the trajectory.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    metadata = {
 | 
			
		||||
        'mp_config': {
 | 
			
		||||
            'ProMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'p_gains': 0.6,
 | 
			
		||||
                    'd_gains': 0.075,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'DMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'p_gains': 0.6,
 | 
			
		||||
                    'd_gains': 0.075,
 | 
			
		||||
                },
 | 
			
		||||
                'trajectory_generator_kwargs': {
 | 
			
		||||
                    'weight_scale': 50,
 | 
			
		||||
                },
 | 
			
		||||
                'phase_generator_kwargs': {
 | 
			
		||||
                    'alpha_phase': 2,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'ProDMP': {},
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
 | 
			
		||||
                 allow_self_collision: bool = False, ):
 | 
			
		||||
        super().__init__(n_links, random_start, allow_self_collision)
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,26 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
 | 
			
		||||
 | 
			
		||||
class MPWrapper(RawInterfaceWrapper):
 | 
			
		||||
 | 
			
		||||
    mp_config = {
 | 
			
		||||
        'ProMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'controller_type': 'velocity',
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'DMP': {
 | 
			
		||||
            'controller_kwargs': {
 | 
			
		||||
                'controller_type': 'velocity',
 | 
			
		||||
            },
 | 
			
		||||
            'trajectory_generator_kwargs': {
 | 
			
		||||
                'weight_scale': 50,
 | 
			
		||||
            },
 | 
			
		||||
            'phase_generator_kwargs': {
 | 
			
		||||
                'alpha_phase': 2,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        'ProDMP': {},
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def context_mask(self):
 | 
			
		||||
        return np.hstack([
 | 
			
		||||
 | 
			
		||||
@ -12,30 +12,6 @@ from . import MPWrapper
 | 
			
		||||
 | 
			
		||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
 | 
			
		||||
 | 
			
		||||
    metadata = {
 | 
			
		||||
        'mp_config': {
 | 
			
		||||
            'ProMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'controller_type': 'velocity',
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'DMP': {
 | 
			
		||||
                'wrappers': [MPWrapper],
 | 
			
		||||
                'controller_kwargs': {
 | 
			
		||||
                    'controller_type': 'velocity',
 | 
			
		||||
                },
 | 
			
		||||
                'trajectory_generator_kwargs': {
 | 
			
		||||
                    'weight_scale': 50,
 | 
			
		||||
                },
 | 
			
		||||
                'phase_generator_kwargs': {
 | 
			
		||||
                    'alpha_phase': 2,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
            'ProDMP': {},
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
 | 
			
		||||
                 target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -142,14 +142,14 @@ def register_mp(id, mp_wrapper, mp_type):
 | 
			
		||||
def bb_env_constructor(underlying_id, mp_wrapper, mp_type, mp_config_override={}, **kwargs):
 | 
			
		||||
    raw_underlying_env = gym_make(underlying_id, **kwargs)
 | 
			
		||||
    underlying_env = mp_wrapper(raw_underlying_env)
 | 
			
		||||
    env_metadata = underlying_env.metadata
 | 
			
		||||
 | 
			
		||||
    metadata_config = copy.deepcopy(env_metadata.get('mp_config', {}).get(mp_type, {}))
 | 
			
		||||
    global_inherit_defaults = env_metadata.get('mp_config', {}).get('inherit_defaults', True)
 | 
			
		||||
    inherit_defaults = metadata_config.pop('inherit_defaults', global_inherit_defaults)
 | 
			
		||||
    mp_config = underlying_env.get('mp_config', {})
 | 
			
		||||
    active_mp_config = copy.deepcopy(mp_config.get(mp_type, {}))
 | 
			
		||||
    global_inherit_defaults = mp_config.get('inherit_defaults', True)
 | 
			
		||||
    inherit_defaults = active_mp_config.pop('inherit_defaults', global_inherit_defaults)
 | 
			
		||||
 | 
			
		||||
    config = copy.deepcopy(_BB_DEFAULTS[mp_type]) if inherit_defaults else {}
 | 
			
		||||
    nested_update(config, metadata_config)
 | 
			
		||||
    nested_update(config, active_mp_config)
 | 
			
		||||
    nested_update(config, mp_config_override)
 | 
			
		||||
 | 
			
		||||
    wrappers = config.pop("wrappers")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user