This commit is contained in:
Hongyi Zhou 2022-10-26 15:18:37 +02:00
parent a8ffa791b8
commit be6137ec81
3 changed files with 23 additions and 7 deletions

View File

@ -22,6 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
max_replan_times: int = 1,
desired_conditioning: bool = False
):
"""
@ -73,6 +74,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.condition_pos = None
self.condition_vel = None
self.max_replan_times = max_replan_times
self.replan_counts = 0
def observation(self, observation):
# return context space if we are
if self.return_context_observation:
@ -155,6 +159,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict()
done = False
self.replan_counts += 1
for t, (pos, vel) in enumerate(zip(position, velocity)):
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
@ -173,14 +178,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if self.render_kwargs:
self.env.render(**self.render_kwargs)
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps):
if done or (self.replan_counts < self.max_replan_times
and self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps)):
if self.desired_conditioning:
self.condition_pos = pos
self.condition_vel = vel
else:
self.condition_pos = self.current_pos
self.condition_vel = self.current_vel
break
infos.update({k: v[:t+1] for k, v in infos.items()})

View File

@ -498,10 +498,11 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([100., 166., 500., 1000.])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
register(
id=_env_id,

View File

@ -9,7 +9,15 @@ Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
def test_prodmp_envs(env_id: str):
def test_replanning_envs(env_id: str):
"""Tests that ProDMP environments run without errors using random actions."""
run_env(env_id)
run_env(env_id, iterations=4)
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
# def test_replanning_determinism(env_id: str):
# """Tests that ProDMP environments are deterministic."""
# run_env_determinism(env_id, 0)