From fd4f9ae0bcb9193c00367eff8bc2097d50d97d8b Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Tue, 1 Nov 2022 22:51:43 +0100 Subject: [PATCH] add tests for replanning env & adapt observation space for box pushing & add max_planning_times to replanning tasks --- fancy_gym/black_box/black_box_wrapper.py | 25 ++++++++++--------- fancy_gym/envs/__init__.py | 3 ++- .../mujoco/box_pushing/box_pushing_env.py | 6 ++--- .../envs/mujoco/box_pushing/mp_wrapper.py | 13 ++++------ .../examples/examples_movement_primitives.py | 10 ++++---- fancy_gym/meta/__init__.py | 4 +++ test/test_replanning_envs.py | 12 ++++----- test/utils.py | 4 +-- 8 files changed, 40 insertions(+), 37 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 5269d29..dc5445e 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): replanning_schedule: Optional[ Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, reward_aggregation: Callable[[np.ndarray], float] = np.sum, - max_replan_times: int = 1, + max_planning_times: int = 1, desired_conditioning: bool = False ): """ @@ -61,8 +61,8 @@ class BlackBoxWrapper(gym.ObservationWrapper): # self.traj_gen.basis_gn.show_basis(plot=True) # spaces - # self.return_context_observation = not (learn_sub_trajectories or self.do_replanning) - self.return_context_observation = True + self.return_context_observation = not (learn_sub_trajectories or self.do_replanning) + # self.return_context_observation = True self.traj_gen_action_space = self._get_traj_gen_action_space() self.action_space = self._get_action_space() @@ -82,8 +82,8 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.condition_pos = None self.condition_vel = None - self.max_replan_times = max_replan_times - self.replan_counts = 0 + self.max_planning_times = max_planning_times + self.plan_counts = 0 def observation(self, observation): # return context space if we are @@ -176,7 +176,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): infos = dict() done = False - self.replan_counts += 1 + self.plan_counts += 1 for t, (pos, vel) in enumerate(zip(position, velocity)): step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) @@ -197,12 +197,12 @@ class BlackBoxWrapper(gym.ObservationWrapper): if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, t + 1 + self.current_traj_steps): - if self.desired_conditioning: - self.condition_pos = pos - self.condition_vel = vel - else: - self.condition_pos = self.current_pos - self.condition_vel = self.current_vel + + if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times: + continue + + self.condition_pos = pos if self.desired_conditioning else self.current_pos + self.condition_vel = vel if self.desired_conditioning else self.current_vel break @@ -229,5 +229,6 @@ class BlackBoxWrapper(gym.ObservationWrapper): def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None): self.current_traj_steps = 0 + self.plan_counts = 0 self.traj_gen.reset() return super(BlackBoxWrapper, self).reset() diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 91e41db..d3dfa8e 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -87,6 +87,7 @@ DEFAULT_BB_DICT_ProDMP = { }, "black_box_kwargs": { 'replanning_schedule': None, + 'max_planning_times': None, 'verbose': 2 } } @@ -506,7 +507,7 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10. kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 - kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4 + kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0 register( id=_env_id, diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 834c5e0..37babf9 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -134,9 +134,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): obs = np.concatenate([ self.data.qpos[:7].copy(), # joint position self.data.qvel[:7].copy(), # joint velocity - self.data.qfrc_bias[:7].copy(), # joint gravity compensation - self.data.site("rod_tip").xpos.copy(), # position of rod tip - self.data.body("push_rod").xquat.copy(), # orientation of rod + # self.data.qfrc_bias[:7].copy(), # joint gravity compensation + # self.data.site("rod_tip").xpos.copy(), # position of rod tip + # self.data.body("push_rod").xquat.copy(), # orientation of rod self.data.body("box_0").xpos.copy(), # position of box self.data.body("box_0").xquat.copy(), # orientation of box self.data.body("replan_target_pos").xpos.copy(), # position of target diff --git a/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py index ccf4c57..09b2d65 100644 --- a/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py @@ -11,16 +11,13 @@ class MPWrapper(RawInterfaceWrapper): @property def context_mask(self): return np.hstack([ - [True] * 7, # joints position - [True] * 7, # joints velocity - [False] * 7, # joints gravity compensation - [False] * 3, # position of rod tip - [False] * 4, # orientation of rod - [True] * 3, # position of box - [True] * 4, # orientation of box + [False] * 7, # joints position + [False] * 7, # joints velocity + [False] * 3, # position of box + [False] * 4, # orientation of box [True] * 3, # position of target [True] * 4, # orientation of target - [True] * 1, # time + # [True] * 1, # time ]) @property diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index 707dccd..e19eacb 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': render = True # DMP - # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) + example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # ProMP - # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) - # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) + example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) + example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) # ProDMP example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render) # Altered basis functions - # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) + obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # Custom MP - # example_fully_custom_mp(seed=10, iterations=1, render=render) + example_fully_custom_mp(seed=10, iterations=1, render=render) diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py index 4fb23b2..b9f0dca 100644 --- a/fancy_gym/meta/__init__.py +++ b/fancy_gym/meta/__init__.py @@ -43,6 +43,10 @@ DEFAULT_BB_DICT_ProDMP = { "basis_generator_kwargs": { 'basis_generator_type': 'prodmp', 'num_basis': 5 + }, + "black_box_kwargs": { + 'replanning_schedule': None, + 'max_planning_times': None, } } diff --git a/test/test_replanning_envs.py b/test/test_replanning_envs.py index c6d697e..300faed 100644 --- a/test/test_replanning_envs.py +++ b/test/test_replanning_envs.py @@ -11,13 +11,13 @@ All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'] -@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS) +@pytest.mark.parametrize('env_id', All_ProDMP_IDS) def test_replanning_envs(env_id: str): """Tests that ProDMP environments run without errors using random actions.""" - run_env(env_id, iterations=4) + run_env(env_id) -# @pytest.mark.parametrize('env_id', All_ProDMP_IDS) -# def test_replanning_determinism(env_id: str): -# """Tests that ProDMP environments are deterministic.""" -# run_env_determinism(env_id, 0) +@pytest.mark.parametrize('env_id', All_ProDMP_IDS) +def test_replanning_determinism(env_id: str): + """Tests that ProDMP environments are deterministic.""" + run_env_determinism(env_id, 0) diff --git a/test/utils.py b/test/utils.py index 7ed8d61..dff2292 100644 --- a/test/utils.py +++ b/test/utils.py @@ -49,8 +49,8 @@ def run_env(env_id, iterations=None, seed=0, render=False): if done: break - - assert done, "Done flag is not True after end of episode." + if not hasattr(env, "replanning_schedule"): + assert done, "Done flag is not True after end of episode." observations.append(obs) env.close() del env