diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index dc3d0d5..a73915e 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -22,6 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): replanning_schedule: Optional[ Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None, reward_aggregation: Callable[[np.ndarray], float] = np.sum, + max_replan_times: int = 1, desired_conditioning: bool = False ): """ @@ -73,6 +74,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.condition_pos = None self.condition_vel = None + self.max_replan_times = max_replan_times + self.replan_counts = 0 + def observation(self, observation): # return context space if we are if self.return_context_observation: @@ -155,6 +159,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): infos = dict() done = False + self.replan_counts += 1 for t, (pos, vel) in enumerate(zip(position, velocity)): step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) @@ -173,14 +178,16 @@ class BlackBoxWrapper(gym.ObservationWrapper): if self.render_kwargs: self.env.render(**self.render_kwargs) - if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, - t + 1 + self.current_traj_steps): + if done or (self.replan_counts < self.max_replan_times + and self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, + t + 1 + self.current_traj_steps)): if self.desired_conditioning: self.condition_pos = pos self.condition_vel = vel else: self.condition_pos = self.current_pos self.condition_vel = self.current_vel + break infos.update({k: v[:t+1] for k, v in infos.items()}) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index d38194b..e2bea2e 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -498,11 +498,12 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02]) # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01 - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([100., 166., 500., 1000.]) - kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1. + kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.]) + kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1. kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10. - kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0 + kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4 + kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0 register( id=_env_id, entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper', diff --git a/test/test_replanning_envs.py b/test/test_replanning_envs.py index 227e885..c6d697e 100644 --- a/test/test_replanning_envs.py +++ b/test/test_replanning_envs.py @@ -9,7 +9,15 @@ Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'] All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP'] + + @pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS) -def test_prodmp_envs(env_id: str): +def test_replanning_envs(env_id: str): """Tests that ProDMP environments run without errors using random actions.""" - run_env(env_id) \ No newline at end of file + run_env(env_id, iterations=4) + +# @pytest.mark.parametrize('env_id', All_ProDMP_IDS) +# def test_replanning_determinism(env_id: str): +# """Tests that ProDMP environments are deterministic.""" +# run_env_determinism(env_id, 0) +