From 2735e0bf24fa78ea2a45aa92bf09bf98e42d2b44 Mon Sep 17 00:00:00 2001 From: Hongyi Zhou Date: Fri, 25 Nov 2022 22:34:46 +0100 Subject: [PATCH] add contextual obs option to invalid trajectory callback --- fancy_gym/black_box/black_box_wrapper.py | 17 ++++++++--------- .../envs/mujoco/table_tennis/mp_wrapper.py | 6 ++++-- .../examples/examples_movement_primitives.py | 13 +++++++------ 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 1dddf2c..a8baa84 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -169,8 +169,6 @@ class BlackBoxWrapper(gym.ObservationWrapper): infos = dict() done = False - - if traj_is_valid: self.plan_steps += 1 for t, (pos, vel) in enumerate(zip(position, velocity)): @@ -207,18 +205,19 @@ class BlackBoxWrapper(gym.ObservationWrapper): infos.update({k: v[:t+1] for k, v in infos.items()}) self.current_traj_steps += t + 1 - if self.verbose >= 2: - infos['positions'] = position - infos['velocities'] = velocity - infos['step_actions'] = actions[:t + 1] - infos['step_observations'] = observations[:t + 1] - infos['step_rewards'] = rewards[:t + 1] + if self.verbose >= 2: + infos['positions'] = position + infos['velocities'] = velocity + infos['step_actions'] = actions[:t + 1] + infos['step_observations'] = observations[:t + 1] + infos['step_rewards'] = rewards[:t + 1] infos['trajectory_length'] = t + 1 trajectory_return = self.reward_aggregation(rewards[:t + 1]) return self.observation(obs), trajectory_return, done, infos else: - obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity) + obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity, + self.return_context_observation) return self.observation(obs), trajectory_return, done, infos def render(self, **kwargs): diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py index dcb2306..1da8de5 100644 --- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py @@ -55,8 +55,8 @@ class MPWrapper(RawInterfaceWrapper): return False return True - def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \ - -> Tuple[np.ndarray, float, bool, dict]: + def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray, + return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]: tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])) delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])) violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0)) @@ -64,6 +64,8 @@ class MPWrapper(RawInterfaceWrapper): invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \ violate_high_bound_error + violate_low_bound_error obs = np.concatenate([self.get_obs(), np.array([0])]) + if return_contextual_obs: + obs = self.get_obs() return obs, -invalid_penalty, True, { "hit_ball": [False], "ball_returned_success": [False], diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py index 445b8b9..7ac6c69 100644 --- a/fancy_gym/examples/examples_movement_primitives.py +++ b/fancy_gym/examples/examples_movement_primitives.py @@ -157,17 +157,18 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': render = True # DMP - example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) + # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # ProMP - example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) - example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) + # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) + # example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render) + example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render) # ProDMP - example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) + # example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render) # Altered basis functions - obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) + # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render) # Custom MP - example_fully_custom_mp(seed=10, iterations=1, render=render) + # example_fully_custom_mp(seed=10, iterations=1, render=render)