add tests for replanning env & adapt observation space for box pushing & add max_planning_times to replanning tasks
This commit is contained in:
parent
4f9b1fad25
commit
fd4f9ae0bc
@ -23,7 +23,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
replanning_schedule: Optional[
|
||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||
max_replan_times: int = 1,
|
||||
max_planning_times: int = 1,
|
||||
desired_conditioning: bool = False
|
||||
):
|
||||
"""
|
||||
@ -61,8 +61,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
# self.traj_gen.basis_gn.show_basis(plot=True)
|
||||
# spaces
|
||||
# self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||
self.return_context_observation = True
|
||||
self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
|
||||
# self.return_context_observation = True
|
||||
self.traj_gen_action_space = self._get_traj_gen_action_space()
|
||||
self.action_space = self._get_action_space()
|
||||
|
||||
@ -82,8 +82,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.condition_pos = None
|
||||
self.condition_vel = None
|
||||
|
||||
self.max_replan_times = max_replan_times
|
||||
self.replan_counts = 0
|
||||
self.max_planning_times = max_planning_times
|
||||
self.plan_counts = 0
|
||||
|
||||
def observation(self, observation):
|
||||
# return context space if we are
|
||||
@ -176,7 +176,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
infos = dict()
|
||||
done = False
|
||||
|
||||
self.replan_counts += 1
|
||||
self.plan_counts += 1
|
||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
||||
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
||||
@ -197,12 +197,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||
t + 1 + self.current_traj_steps):
|
||||
if self.desired_conditioning:
|
||||
self.condition_pos = pos
|
||||
self.condition_vel = vel
|
||||
else:
|
||||
self.condition_pos = self.current_pos
|
||||
self.condition_vel = self.current_vel
|
||||
|
||||
if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
|
||||
continue
|
||||
|
||||
self.condition_pos = pos if self.desired_conditioning else self.current_pos
|
||||
self.condition_vel = vel if self.desired_conditioning else self.current_vel
|
||||
|
||||
break
|
||||
|
||||
@ -229,5 +229,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
|
||||
self.current_traj_steps = 0
|
||||
self.plan_counts = 0
|
||||
self.traj_gen.reset()
|
||||
return super(BlackBoxWrapper, self).reset()
|
||||
|
@ -87,6 +87,7 @@ DEFAULT_BB_DICT_ProDMP = {
|
||||
},
|
||||
"black_box_kwargs": {
|
||||
'replanning_schedule': None,
|
||||
'max_planning_times': None,
|
||||
'verbose': 2
|
||||
}
|
||||
}
|
||||
@ -506,7 +507,7 @@ for _v in _versions:
|
||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
|
||||
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
||||
register(
|
||||
id=_env_id,
|
||||
|
@ -134,9 +134,9 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
obs = np.concatenate([
|
||||
self.data.qpos[:7].copy(), # joint position
|
||||
self.data.qvel[:7].copy(), # joint velocity
|
||||
self.data.qfrc_bias[:7].copy(), # joint gravity compensation
|
||||
self.data.site("rod_tip").xpos.copy(), # position of rod tip
|
||||
self.data.body("push_rod").xquat.copy(), # orientation of rod
|
||||
# self.data.qfrc_bias[:7].copy(), # joint gravity compensation
|
||||
# self.data.site("rod_tip").xpos.copy(), # position of rod tip
|
||||
# self.data.body("push_rod").xquat.copy(), # orientation of rod
|
||||
self.data.body("box_0").xpos.copy(), # position of box
|
||||
self.data.body("box_0").xquat.copy(), # orientation of box
|
||||
self.data.body("replan_target_pos").xpos.copy(), # position of target
|
||||
|
@ -11,16 +11,13 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
@property
|
||||
def context_mask(self):
|
||||
return np.hstack([
|
||||
[True] * 7, # joints position
|
||||
[True] * 7, # joints velocity
|
||||
[False] * 7, # joints gravity compensation
|
||||
[False] * 3, # position of rod tip
|
||||
[False] * 4, # orientation of rod
|
||||
[True] * 3, # position of box
|
||||
[True] * 4, # orientation of box
|
||||
[False] * 7, # joints position
|
||||
[False] * 7, # joints velocity
|
||||
[False] * 3, # position of box
|
||||
[False] * 4, # orientation of box
|
||||
[True] * 3, # position of target
|
||||
[True] * 4, # orientation of target
|
||||
[True] * 1, # time
|
||||
# [True] * 1, # time
|
||||
])
|
||||
|
||||
@property
|
||||
|
@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
if __name__ == '__main__':
|
||||
render = True
|
||||
# DMP
|
||||
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
|
||||
|
||||
# ProMP
|
||||
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
|
||||
example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# ProDMP
|
||||
example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=16, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom MP
|
||||
# example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
example_fully_custom_mp(seed=10, iterations=1, render=render)
|
||||
|
@ -43,6 +43,10 @@ DEFAULT_BB_DICT_ProDMP = {
|
||||
"basis_generator_kwargs": {
|
||||
'basis_generator_type': 'prodmp',
|
||||
'num_basis': 5
|
||||
},
|
||||
"black_box_kwargs": {
|
||||
'replanning_schedule': None,
|
||||
'max_planning_times': None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -11,13 +11,13 @@ All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
|
||||
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||
def test_replanning_envs(env_id: str):
|
||||
"""Tests that ProDMP environments run without errors using random actions."""
|
||||
run_env(env_id, iterations=4)
|
||||
run_env(env_id)
|
||||
|
||||
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||
# def test_replanning_determinism(env_id: str):
|
||||
# """Tests that ProDMP environments are deterministic."""
|
||||
# run_env_determinism(env_id, 0)
|
||||
@pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||
def test_replanning_determinism(env_id: str):
|
||||
"""Tests that ProDMP environments are deterministic."""
|
||||
run_env_determinism(env_id, 0)
|
||||
|
||||
|
@ -49,8 +49,8 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
assert done, "Done flag is not True after end of episode."
|
||||
if not hasattr(env, "replanning_schedule"):
|
||||
assert done, "Done flag is not True after end of episode."
|
||||
observations.append(obs)
|
||||
env.close()
|
||||
del env
|
||||
|
Loading…
Reference in New Issue
Block a user