diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index 2efff28..bb1cd28 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -21,7 +21,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                  learn_sub_trajectories: bool = False,
                  replanning_schedule: Optional[
                      Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
-                 reward_aggregation: Callable[[np.ndarray], float] = np.sum
+                 reward_aggregation: Callable[[np.ndarray], float] = np.sum,
+                 desired_conditioning: bool = False
                  ):
         """
         gym.Wrapper for leveraging a black box approach with a trajectory generator.
@@ -67,6 +68,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.render_kwargs = {}
         self.verbose = verbose
 
+        # boundary conditions used when (re)planning the next trajectory segment
+        self.desired_conditioning = desired_conditioning
+        self.condition_pos = None
+        self.condition_vel = None
+
     def observation(self, observation):
         # return context space if we are
         if self.return_context_observation:
@@ -87,7 +93,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
         # TODO we could think about initializing with the previous desired value in order to have a smooth transition
         # at least from the planning point of view.
-        self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
+        # on the first segment there is no previous plan, so condition on the measured state
+        if self.current_traj_steps == 0:
+            self.condition_pos = self.current_pos
+            self.condition_vel = self.current_vel
+        self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
@@ -165,14 +175,24 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                 if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                     t + 1 + self.current_traj_steps):
+                    # store the boundary conditions for the next segment: either the
+                    # last desired (planned) state or the measured state of the robot
+                    if self.desired_conditioning:
+                        self.condition_pos = pos
+                        self.condition_vel = vel
+                    else:
+                        self.condition_pos = self.current_pos
+                        self.condition_vel = self.current_vel
                     break
 
             infos.update({k: v[:t+1] for k, v in infos.items()})
             self.current_traj_steps += t + 1
 
             if self.verbose >= 2:
-                infos['positions'] = position
-                infos['velocities'] = velocity
+                infos['desired_pos'] = position[:t+1]
+                infos['desired_vel'] = velocity[:t+1]
+                infos['current_pos'] = self.current_pos
+                infos['current_vel'] = self.current_vel
                 infos['step_actions'] = actions[:t + 1]
                 infos['step_observations'] = observations[:t + 1]
                 infos['step_rewards'] = rewards[:t + 1]
diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py
index eea6455..834c5e0 100644
--- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py
+++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py
@@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
 if __name__=="__main__":
     env = BoxPushingTemporalSpatialSparse(frame_skip=10)
     env.reset()
-    for i in range(1):
+    for i in range(10):
         env.reset()
         for _ in range(100):
             env.render("human")
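Note on the wrapper change above: at each replanning boundary the wrapper now stores the boundary conditions for the next trajectory segment instead of always reading the measured state at planning time, which matches the existing TODO about smoother transitions "at least from the planning point of view". The stand-alone sketch below restates that rule outside the wrapper; the function name `select_boundary_conditions` and its argument names are illustrative only, not part of fancy_gym:

```python
import numpy as np

def select_boundary_conditions(desired_conditioning: bool,
                               desired_pos: np.ndarray, desired_vel: np.ndarray,
                               current_pos: np.ndarray, current_vel: np.ndarray):
    """Illustrative restatement of the conditioning rule in BlackBoxWrapper.step().

    desired_pos/desired_vel: last planned setpoint of the finished segment.
    current_pos/current_vel: measured robot state at the replanning boundary.
    """
    if desired_conditioning:
        # condition the next segment on the previous plan: the stitched desired
        # trajectory stays smooth even when tracking lags behind
        return desired_pos, desired_vel
    # condition on the measured state: the next plan starts from where the
    # robot actually is, possibly introducing kinks in the desired trajectory
    return current_pos, current_vel
```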
diff --git a/fancy_gym/examples/example_replanning_envs.py b/fancy_gym/examples/example_replanning_envs.py
new file mode 100644
index 0000000..392e9d4
--- /dev/null
+++ b/fancy_gym/examples/example_replanning_envs.py
@@ -0,0 +1,37 @@
+import fancy_gym
+import numpy as np
+import matplotlib.pyplot as plt
+
+def plot_trajectory(traj):
+    plt.figure()
+    # plot the desired position of a single joint over the episode
+    plt.plot(traj[:, 3], label='joint 3 desired position')
+    plt.legend()
+    plt.show()
+
+def run_replanning_envs(env_name="BoxPushingProDMP-v0", seed=1, iterations=1, render=True):
+    env = fancy_gym.make(env_name, seed=seed)
+    env.reset()
+    for i in range(iterations):
+        done = False
+        # 100-step episode, replanned every 25 steps -> four desired-trajectory segments;
+        # info['desired_pos'] / info['desired_vel'] require the wrapper's verbose >= 2
+        desired_pos_traj = np.zeros((100, 7))
+        desired_vel_traj = np.zeros((100, 7))
+        t = 0
+        while not done:
+            ac = env.action_space.sample()
+            obs, reward, done, info = env.step(ac)
+            desired_pos_traj[t: t + 25, :] = info['desired_pos']
+            desired_vel_traj[t: t + 25, :] = info['desired_vel']
+            t += 25
+            if render:
+                env.render(mode="human")
+            if done:
+                env.reset()
+        plot_trajectory(desired_pos_traj)
+    env.close()
+    del env
+
+if __name__ == "__main__":
+    run_replanning_envs(env_name="BoxPushingDenseProDMP-v0", seed=1, iterations=1, render=False)
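Usage note for the renamed info keys: `desired_pos`, `desired_vel`, `current_pos`, and `current_vel` are only populated when the BlackBoxWrapper runs with `verbose >= 2` (see the `if self.verbose >= 2:` branch in the first hunk). A minimal sketch, assuming `BoxPushingDenseProDMP-v0` is registered with a replanning schedule as in the example script:

```python
import fancy_gym

# minimal sketch: inspect one replanned segment's info payload
env = fancy_gym.make("BoxPushingDenseProDMP-v0", seed=1)
env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
if 'desired_pos' in info:  # present only with verbose >= 2
    # desired_* hold the planned segment (shape (t+1, n_joints));
    # current_* hold the measured state at the end of the segment
    print(info['desired_pos'].shape, info['current_pos'])
env.close()
```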