From a9a1d054977e9ed0231959d04b332812ab37ec85 Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Thu, 1 Dec 2022 11:46:09 +0100 Subject: [PATCH] merge master into table-tennis-dev branch --- fancy_gym/black_box/black_box_wrapper.py | 9 +++------ .../examples/examples_movement_primitives.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 4ff685a..d5bd7e6 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -145,11 +145,10 @@ class BlackBoxWrapper(gym.ObservationWrapper): def step(self, action: np.ndarray): """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" - # TODO remove this part, right now only needed for beer pong - # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen) position, velocity = self.get_trajectory(action) + position, velocity = self.env.set_episode_arguments(action, position, velocity) traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity) - # insert validation here + trajectory_length = len(position) rewards = np.zeros(shape=(trajectory_length,)) if self.verbose >= 2: @@ -167,8 +166,6 @@ class BlackBoxWrapper(gym.ObservationWrapper): else: self.plan_steps += 1 for t, (pos, vel) in enumerate(zip(position, velocity)): - current_pos = self.current_pos - current_vel = self.current_vel step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) obs, c_reward, done, info = self.env.step(c_action) @@ -186,7 +183,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): if self.render_kwargs: self.env.render(**self.render_kwargs) - if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, + if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index 7d2edf3..b9f82de 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -161,16 +161,16 @@ if __name__ == '__main__': # ProMP # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) - # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) - # example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render) - # example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render) - # example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render) + example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) + example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render) + example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render) + example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render) # ProDMP - # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) - example_mp("TableTennis4DProDMP-v0", seed=10, iterations=2000, render=render) - # example_mp("TableTennisWindProDMP-v0", seed=10, iterations=20, render=render) - # example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=20, render=render) + example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) + example_mp("TableTennis4DProDMP-v0", seed=10, iterations=20, render=render) + example_mp("TableTennisWindProDMP-v0", seed=10, iterations=20, render=render) + example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=20, render=render) # Altered basis functions # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)