From 344c11d67a71c1fb61bb49771dfce5a0e770813b Mon Sep 17 00:00:00 2001 From: "hongyi.zhou" Date: Fri, 27 Jan 2023 17:50:14 +0100 Subject: [PATCH] updates according to changes request --- .gitignore | 3 + fancy_gym/black_box/black_box_wrapper.py | 78 +++++++++---------- fancy_gym/black_box/raw_interface_wrapper.py | 28 +++++++ .../mujoco/table_tennis/table_tennis_env.py | 70 ++++++++--------- .../mujoco/table_tennis/table_tennis_utils.py | 2 +- 5 files changed, 103 insertions(+), 78 deletions(-) diff --git a/.gitignore b/.gitignore index ec01816..91a91dd 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,6 @@ venv.bak/ /configs/db.cfg legacy/ MUJOCO_LOG.TXT + +# vscode +.vscode diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index acfc00e..fc295e7 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -159,52 +159,52 @@ class BlackBoxWrapper(gym.ObservationWrapper): infos = dict() done = False - if traj_is_valid is False: + if not traj_is_valid: obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity, self.return_context_observation) return self.observation(obs), trajectory_return, done, infos - else: - self.plan_steps += 1 - for t, (pos, vel) in enumerate(zip(position, velocity)): - step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) - c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) - obs, c_reward, done, info = self.env.step(c_action) - rewards[t] = c_reward - if self.verbose >= 2: - actions[t, :] = c_action - observations[t, :] = obs - - for k, v in info.items(): - elems = infos.get(k, [None] * trajectory_length) - elems[t] = v - infos[k] = elems - - if self.render_kwargs: - self.env.render(**self.render_kwargs) - - if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, - t + 1 + self.current_traj_steps) - and self.plan_steps < self.max_planning_times): - - self.condition_pos = pos if self.condition_on_desired else None - self.condition_vel = vel if self.condition_on_desired else None - - break - - infos.update({k: v[:t+1] for k, v in infos.items()}) - self.current_traj_steps += t + 1 + self.plan_steps += 1 + for t, (pos, vel) in enumerate(zip(position, velocity)): + step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) + c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) + obs, c_reward, done, info = self.env.step(c_action) + rewards[t] = c_reward if self.verbose >= 2: - infos['positions'] = position - infos['velocities'] = velocity - infos['step_actions'] = actions[:t + 1] - infos['step_observations'] = observations[:t + 1] - infos['step_rewards'] = rewards[:t + 1] + actions[t, :] = c_action + observations[t, :] = obs - infos['trajectory_length'] = t + 1 - trajectory_return = self.reward_aggregation(rewards[:t + 1]) - return self.observation(obs), trajectory_return, done, infos + for k, v in info.items(): + elems = infos.get(k, [None] * trajectory_length) + elems[t] = v + infos[k] = elems + + if self.render_kwargs: + self.env.render(**self.render_kwargs) + + if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action, + t + 1 + self.current_traj_steps) + and self.plan_steps < self.max_planning_times): + + self.condition_pos = pos if self.condition_on_desired else None + self.condition_vel = vel if self.condition_on_desired else 
None + + break + + infos.update({k: v[:t+1] for k, v in infos.items()}) + self.current_traj_steps += t + 1 + + if self.verbose >= 2: + infos['positions'] = position + infos['velocities'] = velocity + infos['step_actions'] = actions[:t + 1] + infos['step_observations'] = observations[:t + 1] + infos['step_rewards'] = rewards[:t + 1] + + infos['trajectory_length'] = t + 1 + trajectory_return = self.reward_aggregation(rewards[:t + 1]) + return self.observation(obs), trajectory_return, done, infos def render(self, **kwargs): """Only set render options here, such that they can be used during the rollout. diff --git a/fancy_gym/black_box/raw_interface_wrapper.py b/fancy_gym/black_box/raw_interface_wrapper.py index f41faab..7647924 100644 --- a/fancy_gym/black_box/raw_interface_wrapper.py +++ b/fancy_gym/black_box/raw_interface_wrapper.py @@ -56,12 +56,30 @@ class RawInterfaceWrapper(gym.Wrapper): -> Tuple[bool, np.ndarray, np.ndarray]: """ Used to preprocess the action and check if the desired trajectory is valid. + Args: + action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if + specified, else only traj_gen parameters + pos_traj: a vector instance of the raw position trajectory + vel_traj: a vector instance of the raw velocity trajectory + Returns: + validity flag: bool, True if the raw trajectory is valid, False if not + pos_traj: a vector instance of the preprocessed position trajectory + vel_traj: a vector instance of the preprocessed velocity trajectory """ return True, pos_traj, vel_traj def set_episode_arguments(self, action, pos_traj, vel_traj): """ Used to set the arguments for env that valid for the whole episode + deprecated, replaced by preprocessing_and_validity_callback + Args: + action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if + specified, else only traj_gen parameters + pos_traj: a vector instance of the raw position trajectory + vel_traj: a vector instance of the raw velocity trajectory + Returns: + pos_traj: a vector instance of the preprocessed position trajectory + vel_traj: a vector instance of the preprocessed velocity trajectory """ return pos_traj, vel_traj @@ -82,5 +100,15 @@ class RawInterfaceWrapper(gym.Wrapper): def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]: """ Used to return a artificial return from the env if the desired trajectory is invalid. 
+ Args: + action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if + specified, else only traj_gen parameters + pos_traj: a vector instance of the raw position trajectory + vel_traj: a vector instance of the raw velocity trajectory + Returns: + obs: artificial observation if the trajectory is invalid, by default a zero vector + reward: artificial reward if the trajectory is invalid, by default 0 + done: artificial done if the trajectory is invalid, by default True + info: artificial info if the trajectory is invalid, by default empty dict """ return np.zeros(1), 0, True, {} \ No newline at end of file diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index dc717c2..7fb5e9f 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -4,7 +4,7 @@ import numpy as np from gym import utils, spaces from gym.envs.mujoco import MujocoEnv -from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force +from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound import mujoco @@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self._ball_return_success = False self._ball_landing_pos = None self._init_ball_state = None - self._episode_end = False + self._terminated = False self._id_set = False @@ -54,6 +54,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), frame_skip=frame_skip, mujoco_bindings="mujoco") + if ctxt_dim == 2: self.context_bounds = CONTEXT_BOUNDS_2DIMS elif ctxt_dim == 4: @@ -81,7 +82,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): unstable_simulation = False - if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5: + if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5: new_goal_pos = self._generate_goal_pos(random=True) new_goal_pos[1] = -new_goal_pos[1] self._goal_pos = new_goal_pos @@ -96,7 +97,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): except Exception as e: print("Simulation get unstable return with MujocoException: ", e) unstable_simulation = True - self._episode_end = True + self._terminated = True break if not self._hit_ball: @@ -106,33 +107,32 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) if ball_land_on_floor_no_hit: self._ball_landing_pos = self.data.body("target_ball").xpos.copy() - self._episode_end = True + self._terminated = True if self._hit_ball and not self._ball_contact_after_hit: - if not self._ball_contact_after_hit: - if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor - self._ball_contact_after_hit = True - self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() - self._episode_end = True - elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table - self._ball_contact_after_hit = True - self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() - if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side - self._ball_return_success = True - 
self._episode_end = True + if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor + self._ball_contact_after_hit = True + self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() + self._terminated = True + elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table + self._ball_contact_after_hit = True + self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() + if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side + self._ball_return_success = True + self._terminated = True # update ball trajectory & racket trajectory self._ball_traj.append(self.data.body("target_ball").xpos.copy()) self._racket_traj.append(self.data.geom("bat").xpos.copy()) self._steps += 1 - self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end + self._terminated = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._terminated - reward = -25 if unstable_simulation else self._get_reward(self._episode_end) + reward = -25 if unstable_simulation else self._get_reward(self._terminated) land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ if self._ball_landing_pos is not None else 10. - return self._get_obs(), reward, self._episode_end, { + return self._get_obs(), reward, self._terminated, { "hit_ball": self._hit_ball, "ball_returned_success": self._ball_return_success, "land_dist_error": land_dist_err, @@ -173,7 +173,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self._ball_contact_after_hit = False self._ball_return_success = False self._ball_landing_pos = None - self._episode_end = False + self._terminated = False self._ball_traj = [] self._racket_traj = [] return self._get_obs() @@ -195,24 +195,18 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): ]) return obs - def get_obs(self): - return self._get_obs() - - def _get_reward(self, episode_end): - if not episode_end: + def _get_reward(self, terminated): + if not terminated: return 0 - else: - min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1)) - if not self._hit_ball: - return 0.2 * (1 - np.tanh(min_r_b_dist**2)) - else: - if self._ball_landing_pos is None: - min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1)) - return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2)) - else: - min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2]) - over_net_bonus = int(self._ball_landing_pos[0] < 0) - return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus + min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1)) + if not self._hit_ball: + return 0.2 * (1 - np.tanh(min_r_b_dist**2)) + if self._ball_landing_pos is None: + min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1)) + return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2)) + min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2]) + over_net_bonus = int(self._ball_landing_pos[0] < 0) + return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus def _generate_random_ball(self, random_pos=False, random_vel=False): x_pos, y_pos, z_pos = -0.5, 0.35, 1.75 @@ -227,7 +221,7 @@ class 
TableTennisEnv(MujocoEnv, utils.EzPickle): def _generate_valid_init_ball(self, random_pos=False, random_vel=False): init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) - while not check_init_state_validity(init_ball_state): + while not is_init_state_valid(init_ball_state): init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) return init_ball_state diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py index 66f68d2..4d9a2d2 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py @@ -13,7 +13,7 @@ table_y_min = -0.6 table_y_max = 0.6 g = 9.81 -def check_init_state_validity(init_state): +def is_init_state_valid(init_state): assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state) x = init_state[0] y = init_state[1]
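
Note (not part of the patch): to illustrate how the callbacks documented in raw_interface_wrapper.py above are intended to be overridden, here is a minimal, hypothetical sketch. The class name, joint limits, and the -10 penalty are made up for illustration; a concrete wrapper would additionally provide the usual current_pos / current_vel / context_mask members of RawInterfaceWrapper.

import numpy as np
from typing import Tuple

from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper


class ClippingTaskWrapper(RawInterfaceWrapper):
    """Hypothetical task wrapper; names and limits below are illustrative only."""

    # Illustrative joint limits -- a real task would use its own bounds
    # (e.g. jnt_pos_low / jnt_pos_high as in table_tennis_utils).
    _pos_low = np.array([-2.0, -2.0, -2.0])
    _pos_high = np.array([2.0, 2.0, 2.0])

    def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray,
                                            vel_traj: np.ndarray) -> Tuple[bool, np.ndarray, np.ndarray]:
        # Preprocess: clip every waypoint of the desired position trajectory
        # to the (illustrative) joint limits.
        clipped_pos = np.clip(pos_traj, self._pos_low, self._pos_high)
        # Validity flag: the trajectory counts as valid only if nothing was clipped.
        valid = bool(np.allclose(clipped_pos, pos_traj))
        return valid, clipped_pos, vel_traj

    def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray,
                              vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
        # Artificial transition for an invalid trajectory: zero observation,
        # a fixed penalty (illustrative value), done=True so the episode ends
        # without stepping the simulation, and a small info dict.
        obs = np.zeros(self.observation_space.shape)
        return obs, -10.0, True, {"trajectory_valid": False}

With such an override in place, BlackBoxWrapper.step skips the tracking-controller rollout whenever the validity flag is False and instead returns the artificial (obs, reward, done, info) tuple, as shown in the black_box_wrapper.py hunk at the top of this patch.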