add tests for replanning env & adapt observation space for box pushing & add max_planning_times to replanning tasks
This commit is contained in:
parent
4f9b1fad25
commit
fd4f9ae0bc
@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
replanning_schedule: Optional[
|
replanning_schedule: Optional[
|
||||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||||
max_replan_times: int = 1,
|
max_planning_times: int = 1,
|
||||||
desired_conditioning: bool = False
|
desired_conditioning: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -61,8 +61,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
# self.traj_gen.basis_gn.show_basis(plot=True)
|
# self.traj_gen.basis_gn.show_basis(plot=True)
|
||||||
# spaces
|
# spaces
|
||||||
# self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||||
self.return_context_observation = True
|
# self.return_context_observation = True
|
||||||
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
||||||
self.action_space = self._get_action_space()
|
self.action_space = self._get_action_space()
|
||||||
|
|
||||||
@ -82,8 +82,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.condition_pos = None
|
self.condition_pos = None
|
||||||
self.condition_vel = None
|
self.condition_vel = None
|
||||||
|
|
||||||
self.max_replan_times = max_replan_times
|
self.max_planning_times = max_planning_times
|
||||||
self.replan_counts = 0
|
self.plan_counts = 0
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# return context space if we are
|
# return context space if we are
|
||||||
@ -176,7 +176,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
self.replan_counts += 1
|
self.plan_counts += 1
|
||||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||||
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
||||||
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
||||||
@ -197,12 +197,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||||
t + 1 + self.current_traj_steps):
|
t + 1 + self.current_traj_steps):
|
||||||
if self.desired_conditioning:
|
|
||||||
self.condition_pos = pos
|
if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
|
||||||
self.condition_vel = vel
|
continue
|
||||||
else:
|
|
||||||
self.condition_pos = self.current_pos
|
self.condition_pos = pos if self.desired_conditioning else self.current_pos
|
||||||
self.condition_vel = self.current_vel
|
self.condition_vel = vel if self.desired_conditioning else self.current_vel
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -229,5 +229,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
|
||||||
self.current_traj_steps = 0
|
self.current_traj_steps = 0
|
||||||
|
self.plan_counts = 0
|
||||||
self.traj_gen.reset()
|
self.traj_gen.reset()
|
||||||
return super(BlackBoxWrapper, self).reset()
|
return super(BlackBoxWrapper, self).reset()
|
||||||
|
@ -87,6 +87,7 @@ DEFAULT_BB_DICT_ProDMP = {
|
|||||||
},
|
},
|
||||||
"black_box_kwargs": {
|
"black_box_kwargs": {
|
||||||
'replanning_schedule': None,
|
'replanning_schedule': None,
|
||||||
|
'max_planning_times': None,
|
||||||
'verbose': 2
|
'verbose': 2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -506,7 +507,7 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
|
||||||
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
|
@ -134,9 +134,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
obs = np.concatenate([
|
obs = np.concatenate([
|
||||||
self.data.qpos[:7].copy(), # joint position
|
self.data.qpos[:7].copy(), # joint position
|
||||||
self.data.qvel[:7].copy(), # joint velocity
|
self.data.qvel[:7].copy(), # joint velocity
|
||||||
self.data.qfrc_bias[:7].copy(), # joint gravity compensation
|
# self.data.qfrc_bias[:7].copy(), # joint gravity compensation
|
||||||
self.data.site("rod_tip").xpos.copy(), # position of rod tip
|
# self.data.site("rod_tip").xpos.copy(), # position of rod tip
|
||||||
self.data.body("push_rod").xquat.copy(), # orientation of rod
|
# self.data.body("push_rod").xquat.copy(), # orientation of rod
|
||||||
self.data.body("box_0").xpos.copy(), # position of box
|
self.data.body("box_0").xpos.copy(), # position of box
|
||||||
self.data.body("box_0").xquat.copy(), # orientation of box
|
self.data.body("box_0").xquat.copy(), # orientation of box
|
||||||
self.data.body("replan_target_pos").xpos.copy(), # position of target
|
self.data.body("replan_target_pos").xpos.copy(), # position of target
|
||||||
|
@ -11,16 +11,13 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
@property
|
@property
|
||||||
def context_mask(self):
|
def context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[True] * 7, # joints position
|
[False] * 7, # joints position
|
||||||
[True] * 7, # joints velocity
|
[False] * 7, # joints velocity
|
||||||
[False] * 7, # joints gravity compensation
|
[False] * 3, # position of box
|
||||||
[False] * 3, # position of rod tip
|
[False] * 4, # orientation of box
|
||||||
[False] * 4, # orientation of rod
|
|
||||||
[True] * 3, # position of box
|
|
||||||
[True] * 4, # orientation of box
|
|
||||||
[True] * 3, # position of target
|
[True] * 3, # position of target
|
||||||
[True] * 4, # orientation of target
|
[True] * 4, # orientation of target
|
||||||
[True] * 1, # time
|
# [True] * 1, # time
|
||||||
])
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = True
|
render = True
|
||||||
# DMP
|
# DMP
|
||||||
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
# ProMP
|
# ProMP
|
||||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||||
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# ProDMP
|
# ProDMP
|
||||||
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Custom MP
|
# Custom MP
|
||||||
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||||
|
@ -43,6 +43,10 @@ DEFAULT_BB_DICT_ProDMP = {
|
|||||||
"basis_generator_kwargs": {
|
"basis_generator_kwargs": {
|
||||||
'basis_generator_type': 'prodmp',
|
'basis_generator_type': 'prodmp',
|
||||||
'num_basis': 5
|
'num_basis': 5
|
||||||
|
},
|
||||||
|
"black_box_kwargs": {
|
||||||
|
'replanning_schedule': None,
|
||||||
|
'max_planning_times': None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,13 +11,13 @@ All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
|
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||||
def test_replanning_envs(env_id: str):
|
def test_replanning_envs(env_id: str):
|
||||||
"""Tests that ProDMP environments run without errors using random actions."""
|
"""Tests that ProDMP environments run without errors using random actions."""
|
||||||
run_env(env_id, iterations=4)
|
run_env(env_id)
|
||||||
|
|
||||||
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||||
# def test_replanning_determinism(env_id: str):
|
def test_replanning_determinism(env_id: str):
|
||||||
# """Tests that ProDMP environments are deterministic."""
|
"""Tests that ProDMP environments are deterministic."""
|
||||||
# run_env_determinism(env_id, 0)
|
run_env_determinism(env_id, 0)
|
||||||
|
|
||||||
|
@ -49,8 +49,8 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
|||||||
|
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
|
if not hasattr(env, "replanning_schedule"):
|
||||||
assert done, "Done flag is not True after end of episode."
|
assert done, "Done flag is not True after end of episode."
|
||||||
observations.append(obs)
|
observations.append(obs)
|
||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
Loading…
Reference in New Issue
Block a user