add tests for replanning env & adapt observation space for box pushing & add max_planning_times to replanning tasks

This commit is contained in:
Hongyi Zhou 2022-11-01 22:51:43 +01:00
parent 4f9b1fad25
commit fd4f9ae0bc
8 changed files with 40 additions and 37 deletions

View File

@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
max_replan_times: int = 1,
max_planning_times: int = 1,
desired_conditioning: bool = False
):
"""
@ -61,8 +61,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# self.traj_gen.basis_gn.show_basis(plot=True)
# spaces
# self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
self.return_context_observation = True
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
# self.return_context_observation = True
self.traj_gen_action_space = self._get_traj_gen_action_space()
self.action_space = self._get_action_space()
@ -82,8 +82,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.condition_pos = None
self.condition_vel = None
self.max_replan_times = max_replan_times
self.replan_counts = 0
self.max_planning_times = max_planning_times
self.plan_counts = 0
def observation(self, observation):
# return context space if we are
@ -176,7 +176,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict()
done = False
self.replan_counts += 1
self.plan_counts += 1
for t, (pos, vel) in enumerate(zip(position, velocity)):
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
@ -197,12 +197,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps):
if self.desired_conditioning:
self.condition_pos = pos
self.condition_vel = vel
else:
self.condition_pos = self.current_pos
self.condition_vel = self.current_vel
if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
continue
self.condition_pos = pos if self.desired_conditioning else self.current_pos
self.condition_vel = vel if self.desired_conditioning else self.current_vel
break
@ -229,5 +229,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
self.current_traj_steps = 0
self.plan_counts = 0
self.traj_gen.reset()
return super(BlackBoxWrapper, self).reset()

View File

@ -87,6 +87,7 @@ DEFAULT_BB_DICT_ProDMP = {
},
"black_box_kwargs": {
'replanning_schedule': None,
'max_planning_times': None,
'verbose': 2
}
}
@ -506,7 +507,7 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
register(
id=_env_id,

View File

@ -134,9 +134,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
obs = np.concatenate([
self.data.qpos[:7].copy(), # joint position
self.data.qvel[:7].copy(), # joint velocity
self.data.qfrc_bias[:7].copy(), # joint gravity compensation
self.data.site("rod_tip").xpos.copy(), # position of rod tip
self.data.body("push_rod").xquat.copy(), # orientation of rod
# self.data.qfrc_bias[:7].copy(), # joint gravity compensation
# self.data.site("rod_tip").xpos.copy(), # position of rod tip
# self.data.body("push_rod").xquat.copy(), # orientation of rod
self.data.body("box_0").xpos.copy(), # position of box
self.data.body("box_0").xquat.copy(), # orientation of box
self.data.body("replan_target_pos").xpos.copy(), # position of target

View File

@ -11,16 +11,13 @@ class MPWrapper(RawInterfaceWrapper):
@property
def context_mask(self):
return np.hstack([
[True] * 7, # joints position
[True] * 7, # joints velocity
[False] * 7, # joints gravity compensation
[False] * 3, # position of rod tip
[False] * 4, # orientation of rod
[True] * 3, # position of box
[True] * 4, # orientation of box
[False] * 7, # joints position
[False] * 7, # joints velocity
[False] * 3, # position of box
[False] * 4, # orientation of box
[True] * 3, # position of target
[True] * 4, # orientation of target
[True] * 1, # time
# [True] * 1, # time
])
@property

View File

@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = True
# DMP
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
# ProMP
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
# ProDMP
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
# Altered basis functions
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
# Custom MP
# example_fully_custom_mp(seed=10, iterations=1, render=render)
example_fully_custom_mp(seed=10, iterations=1, render=render)

View File

@ -43,6 +43,10 @@ DEFAULT_BB_DICT_ProDMP = {
"basis_generator_kwargs": {
'basis_generator_type': 'prodmp',
'num_basis': 5
},
"black_box_kwargs": {
'replanning_schedule': None,
'max_planning_times': None,
}
}

View File

@ -11,13 +11,13 @@ All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
def test_replanning_envs(env_id: str):
"""Tests that ProDMP environments run without errors using random actions."""
run_env(env_id, iterations=4)
run_env(env_id)
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_determinism(env_id: str):
# """Tests that ProDMP environments are deterministic."""
# run_env_determinism(env_id, 0)
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
def test_replanning_determinism(env_id: str):
"""Tests that ProDMP environments are deterministic."""
run_env_determinism(env_id, 0)

View File

@ -49,8 +49,8 @@ def run_env(env_id, iterations=None, seed=0, render=False):
if done:
break
assert done, "Done flag is not True after end of episode."
if not hasattr(env, "replanning_schedule"):
assert done, "Done flag is not True after end of episode."
observations.append(obs)
env.close()
del env