updates according to changes request

This commit is contained in:
hongyi.zhou 2023-01-27 17:50:14 +01:00
parent 67f684cf14
commit 344c11d67a
5 changed files with 103 additions and 78 deletions

3
.gitignore vendored
View File

@ -111,3 +111,6 @@ venv.bak/
/configs/db.cfg /configs/db.cfg
legacy/ legacy/
MUJOCO_LOG.TXT MUJOCO_LOG.TXT
# vscode
.vscode

View File

@ -159,11 +159,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict() infos = dict()
done = False done = False
if traj_is_valid is False: if not traj_is_valid:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity, obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
self.return_context_observation) self.return_context_observation)
return self.observation(obs), trajectory_return, done, infos return self.observation(obs), trajectory_return, done, infos
else:
self.plan_steps += 1 self.plan_steps += 1
for t, (pos, vel) in enumerate(zip(position, velocity)): for t, (pos, vel) in enumerate(zip(position, velocity)):
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel) step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)

View File

@ -56,12 +56,30 @@ class RawInterfaceWrapper(gym.Wrapper):
-> Tuple[bool, np.ndarray, np.ndarray]: -> Tuple[bool, np.ndarray, np.ndarray]:
""" """
Used to preprocess the action and check if the desired trajectory is valid. Used to preprocess the action and check if the desired trajectory is valid.
Args:
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
specified, else only traj_gen parameters
pos_traj: a vector instance of the raw position trajectory
vel_traj: a vector instance of the raw velocity trajectory
Returns:
validity flag: bool, True if the raw trajectory is valid, False if not
pos_traj: a vector instance of the preprocessed position trajectory
vel_traj: a vector instance of the preprocessed velocity trajectory
""" """
return True, pos_traj, vel_traj return True, pos_traj, vel_traj
def set_episode_arguments(self, action, pos_traj, vel_traj): def set_episode_arguments(self, action, pos_traj, vel_traj):
""" """
Used to set the arguments for env that valid for the whole episode Used to set the arguments for env that valid for the whole episode
deprecated, replaced by preprocessing_and_validity_callback
Args:
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
specified, else only traj_gen parameters
pos_traj: a vector instance of the raw position trajectory
vel_traj: a vector instance of the raw velocity trajectory
Returns:
pos_traj: a vector instance of the preprocessed position trajectory
vel_traj: a vector instance of the preprocessed velocity trajectory
""" """
return pos_traj, vel_traj return pos_traj, vel_traj
@ -82,5 +100,15 @@ class RawInterfaceWrapper(gym.Wrapper):
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]: def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
""" """
Used to return a artificial return from the env if the desired trajectory is invalid. Used to return a artificial return from the env if the desired trajectory is invalid.
Args:
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
specified, else only traj_gen parameters
pos_traj: a vector instance of the raw position trajectory
vel_traj: a vector instance of the raw velocity trajectory
Returns:
obs: artificial observation if the trajectory is invalid, by default a zero vector
reward: artificial reward if the trajectory is invalid, by default 0
done: artificial done if the trajectory is invalid, by default True
info: artificial info if the trajectory is invalid, by default empty dict
""" """
return np.zeros(1), 0, True, {} return np.zeros(1), 0, True, {}

View File

@ -4,7 +4,7 @@ import numpy as np
from gym import utils, spaces from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv from gym.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
import mujoco import mujoco
@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._ball_return_success = False self._ball_return_success = False
self._ball_landing_pos = None self._ball_landing_pos = None
self._init_ball_state = None self._init_ball_state = None
self._episode_end = False self._terminated = False
self._id_set = False self._id_set = False
@ -54,6 +54,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,
mujoco_bindings="mujoco") mujoco_bindings="mujoco")
if ctxt_dim == 2: if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS self.context_bounds = CONTEXT_BOUNDS_2DIMS
elif ctxt_dim == 4: elif ctxt_dim == 4:
@ -81,7 +82,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False unstable_simulation = False
if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5: if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
new_goal_pos = self._generate_goal_pos(random=True) new_goal_pos = self._generate_goal_pos(random=True)
new_goal_pos[1] = -new_goal_pos[1] new_goal_pos[1] = -new_goal_pos[1]
self._goal_pos = new_goal_pos self._goal_pos = new_goal_pos
@ -96,7 +97,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
except Exception as e: except Exception as e:
print("Simulation get unstable return with MujocoException: ", e) print("Simulation get unstable return with MujocoException: ", e)
unstable_simulation = True unstable_simulation = True
self._episode_end = True self._terminated = True
break break
if not self._hit_ball: if not self._hit_ball:
@ -106,33 +107,32 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id) ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
if ball_land_on_floor_no_hit: if ball_land_on_floor_no_hit:
self._ball_landing_pos = self.data.body("target_ball").xpos.copy() self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
self._episode_end = True self._terminated = True
if self._hit_ball and not self._ball_contact_after_hit: if self._hit_ball and not self._ball_contact_after_hit:
if not self._ball_contact_after_hit:
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
self._ball_contact_after_hit = True self._ball_contact_after_hit = True
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
self._episode_end = True self._terminated = True
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
self._ball_contact_after_hit = True self._ball_contact_after_hit = True
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy() self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side
self._ball_return_success = True self._ball_return_success = True
self._episode_end = True self._terminated = True
# update ball trajectory & racket trajectory # update ball trajectory & racket trajectory
self._ball_traj.append(self.data.body("target_ball").xpos.copy()) self._ball_traj.append(self.data.body("target_ball").xpos.copy())
self._racket_traj.append(self.data.geom("bat").xpos.copy()) self._racket_traj.append(self.data.geom("bat").xpos.copy())
self._steps += 1 self._steps += 1
self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end self._terminated = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._terminated
reward = -25 if unstable_simulation else self._get_reward(self._episode_end) reward = -25 if unstable_simulation else self._get_reward(self._terminated)
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
if self._ball_landing_pos is not None else 10. if self._ball_landing_pos is not None else 10.
return self._get_obs(), reward, self._episode_end, { return self._get_obs(), reward, self._terminated, {
"hit_ball": self._hit_ball, "hit_ball": self._hit_ball,
"ball_returned_success": self._ball_return_success, "ball_returned_success": self._ball_return_success,
"land_dist_error": land_dist_err, "land_dist_error": land_dist_err,
@ -173,7 +173,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._ball_contact_after_hit = False self._ball_contact_after_hit = False
self._ball_return_success = False self._ball_return_success = False
self._ball_landing_pos = None self._ball_landing_pos = None
self._episode_end = False self._terminated = False
self._ball_traj = [] self._ball_traj = []
self._racket_traj = [] self._racket_traj = []
return self._get_obs() return self._get_obs()
@ -195,21 +195,15 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
]) ])
return obs return obs
def get_obs(self): def _get_reward(self, terminated):
return self._get_obs() if not terminated:
def _get_reward(self, episode_end):
if not episode_end:
return 0 return 0
else:
min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1)) min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
if not self._hit_ball: if not self._hit_ball:
return 0.2 * (1 - np.tanh(min_r_b_dist**2)) return 0.2 * (1 - np.tanh(min_r_b_dist**2))
else:
if self._ball_landing_pos is None: if self._ball_landing_pos is None:
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1)) min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2)) return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
else:
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2]) min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
over_net_bonus = int(self._ball_landing_pos[0] < 0) over_net_bonus = int(self._ball_landing_pos[0] < 0)
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
@ -227,7 +221,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def _generate_valid_init_ball(self, random_pos=False, random_vel=False): def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
while not check_init_state_validity(init_ball_state): while not is_init_state_valid(init_ball_state):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel) init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state return init_ball_state

View File

@ -13,7 +13,7 @@ table_y_min = -0.6
table_y_max = 0.6 table_y_max = 0.6
g = 9.81 g = 9.81
def check_init_state_validity(init_state): def is_init_state_valid(init_state):
assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state) assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
x = init_state[0] x = init_state[0]
y = init_state[1] y = init_state[1]