updates according to changes request

hongyi.zhou 2023-01-27 17:50:14 +01:00
parent 67f684cf14
commit 344c11d67a
5 changed files with 103 additions and 78 deletions

.gitignore

@@ -111,3 +111,6 @@ venv.bak/
 /configs/db.cfg
 legacy/
 MUJOCO_LOG.TXT
+
+# vscode
+.vscode


@@ -159,52 +159,52 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         infos = dict()
         done = False

-        if traj_is_valid is False:
+        if not traj_is_valid:
             obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
                                                                                  self.return_context_observation)
             return self.observation(obs), trajectory_return, done, infos
-        else:
-            self.plan_steps += 1
-            for t, (pos, vel) in enumerate(zip(position, velocity)):
-                step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
-                c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
-                obs, c_reward, done, info = self.env.step(c_action)
-                rewards[t] = c_reward
-                if self.verbose >= 2:
-                    actions[t, :] = c_action
-                    observations[t, :] = obs
-
-                for k, v in info.items():
-                    elems = infos.get(k, [None] * trajectory_length)
-                    elems[t] = v
-                    infos[k] = elems
-
-                if self.render_kwargs:
-                    self.env.render(**self.render_kwargs)
-
-                if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
-                                                     t + 1 + self.current_traj_steps)
-                            and self.plan_steps < self.max_planning_times):
-                    self.condition_pos = pos if self.condition_on_desired else None
-                    self.condition_vel = vel if self.condition_on_desired else None
-                    break
-
-            infos.update({k: v[:t+1] for k, v in infos.items()})
-            self.current_traj_steps += t + 1
-
-            if self.verbose >= 2:
-                infos['positions'] = position
-                infos['velocities'] = velocity
-                infos['step_actions'] = actions[:t + 1]
-                infos['step_observations'] = observations[:t + 1]
-                infos['step_rewards'] = rewards[:t + 1]
-
-            infos['trajectory_length'] = t + 1
-            trajectory_return = self.reward_aggregation(rewards[:t + 1])
-            return self.observation(obs), trajectory_return, done, infos
+
+        self.plan_steps += 1
+        for t, (pos, vel) in enumerate(zip(position, velocity)):
+            step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
+            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
+            obs, c_reward, done, info = self.env.step(c_action)
+            rewards[t] = c_reward
+            if self.verbose >= 2:
+                actions[t, :] = c_action
+                observations[t, :] = obs
+
+            for k, v in info.items():
+                elems = infos.get(k, [None] * trajectory_length)
+                elems[t] = v
+                infos[k] = elems
+
+            if self.render_kwargs:
+                self.env.render(**self.render_kwargs)
+
+            if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                 t + 1 + self.current_traj_steps)
+                        and self.plan_steps < self.max_planning_times):
+                self.condition_pos = pos if self.condition_on_desired else None
+                self.condition_vel = vel if self.condition_on_desired else None
+                break
+
+        infos.update({k: v[:t+1] for k, v in infos.items()})
+        self.current_traj_steps += t + 1
+
+        if self.verbose >= 2:
+            infos['positions'] = position
+            infos['velocities'] = velocity
+            infos['step_actions'] = actions[:t + 1]
+            infos['step_observations'] = observations[:t + 1]
+            infos['step_rewards'] = rewards[:t + 1]
+
+        infos['trajectory_length'] = t + 1
+        trajectory_return = self.reward_aggregation(rewards[:t + 1])
+        return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.


@@ -56,12 +56,30 @@ class RawInterfaceWrapper(gym.Wrapper):
             -> Tuple[bool, np.ndarray, np.ndarray]:
         """
         Used to preprocess the action and check if the desired trajectory is valid.
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
+                specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            validity flag: bool, True if the raw trajectory is valid, False if not
+            pos_traj: a vector instance of the preprocessed position trajectory
+            vel_traj: a vector instance of the preprocessed velocity trajectory
         """
         return True, pos_traj, vel_traj

     def set_episode_arguments(self, action, pos_traj, vel_traj):
         """
         Used to set the arguments for env that valid for the whole episode
+        deprecated, replaced by preprocessing_and_validity_callback
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
+                specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            pos_traj: a vector instance of the preprocessed position trajectory
+            vel_traj: a vector instance of the preprocessed velocity trajectory
         """
         return pos_traj, vel_traj

@@ -82,5 +100,15 @@ class RawInterfaceWrapper(gym.Wrapper):
     def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
         """
         Used to return a artificial return from the env if the desired trajectory is invalid.
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
+                specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            obs: artificial observation if the trajectory is invalid, by default a zero vector
+            reward: artificial reward if the trajectory is invalid, by default 0
+            done: artificial done if the trajectory is invalid, by default True
+            info: artificial info if the trajectory is invalid, by default empty dict
+        """
         return np.zeros(1), 0, True, {}
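For reference, a hypothetical subclass wiring the documented hooks together, assuming `RawInterfaceWrapper` is importable as above; the joint-limit check, penalty value, and observation shape are illustrative assumptions, not part of this commit:

    from typing import Tuple

    import numpy as np

    class ClippedTrajWrapper(RawInterfaceWrapper):
        def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
                -> Tuple[bool, np.ndarray, np.ndarray]:
            # Flag trajectories that leave a made-up joint range [-2, 2].
            valid = bool(np.all(np.abs(pos_traj) <= 2.0))
            return valid, pos_traj, vel_traj

        def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
                -> Tuple[np.ndarray, float, bool, dict]:
            # Artificial transition: zero observation, fixed penalty, episode done.
            return np.zeros(self.observation_space.shape), -10.0, True, {'traj_valid': False}

With this in place, BlackBoxWrapper's early return above hands the penalty transition straight to the agent without executing a single simulation step.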


@@ -4,7 +4,7 @@ import numpy as np
 from gym import utils, spaces
 from gym.envs.mujoco import MujocoEnv

-from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
+from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound

 import mujoco

@@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_return_success = False
         self._ball_landing_pos = None
         self._init_ball_state = None
-        self._episode_end = False
+        self._terminated = False

         self._id_set = False

@@ -54,6 +54,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
                            model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
                            frame_skip=frame_skip,
                            mujoco_bindings="mujoco")
+
         if ctxt_dim == 2:
             self.context_bounds = CONTEXT_BOUNDS_2DIMS
         elif ctxt_dim == 4:

@@ -81,7 +82,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         unstable_simulation = False

-        if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
+        if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
             new_goal_pos = self._generate_goal_pos(random=True)
             new_goal_pos[1] = -new_goal_pos[1]
             self._goal_pos = new_goal_pos

@@ -96,7 +97,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             except Exception as e:
                 print("Simulation get unstable return with MujocoException: ", e)
                 unstable_simulation = True
-                self._episode_end = True
+                self._terminated = True
                 break

         if not self._hit_ball:

@@ -106,33 +107,32 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
             if ball_land_on_floor_no_hit:
                 self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
-                self._episode_end = True
+                self._terminated = True

-        if self._hit_ball:
-            if not self._ball_contact_after_hit:
-                if self._contact_checker(self._ball_contact_id, self._floor_contact_id):  # first check contact with floor
-                    self._ball_contact_after_hit = True
-                    self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
-                    self._episode_end = True
-                elif self._contact_checker(self._ball_contact_id, self._table_contact_id):  # second check contact with table
-                    self._ball_contact_after_hit = True
-                    self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
-                    if self._ball_landing_pos[0] < 0.:  # ball lands on the opponent side
-                        self._ball_return_success = True
-                    self._episode_end = True
+        if self._hit_ball and not self._ball_contact_after_hit:
+            if self._contact_checker(self._ball_contact_id, self._floor_contact_id):  # first check contact with floor
+                self._ball_contact_after_hit = True
+                self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
+                self._terminated = True
+            elif self._contact_checker(self._ball_contact_id, self._table_contact_id):  # second check contact with table
+                self._ball_contact_after_hit = True
+                self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
+                if self._ball_landing_pos[0] < 0.:  # ball lands on the opponent side
+                    self._ball_return_success = True
+                self._terminated = True

         # update ball trajectory & racket trajectory
         self._ball_traj.append(self.data.body("target_ball").xpos.copy())
         self._racket_traj.append(self.data.geom("bat").xpos.copy())

         self._steps += 1
-        self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end
+        self._terminated = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._terminated

-        reward = -25 if unstable_simulation else self._get_reward(self._episode_end)
+        reward = -25 if unstable_simulation else self._get_reward(self._terminated)

         land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
             if self._ball_landing_pos is not None else 10.

-        return self._get_obs(), reward, self._episode_end, {
+        return self._get_obs(), reward, self._terminated, {
             "hit_ball": self._hit_ball,
             "ball_returned_success": self._ball_return_success,
             "land_dist_error": land_dist_err,

@@ -173,7 +173,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_contact_after_hit = False
         self._ball_return_success = False
         self._ball_landing_pos = None
-        self._episode_end = False
+        self._terminated = False
         self._ball_traj = []
         self._racket_traj = []
         return self._get_obs()

@@ -195,24 +195,18 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         ])
         return obs

-    def get_obs(self):
-        return self._get_obs()
-
-    def _get_reward(self, episode_end):
-        if not episode_end:
+    def _get_reward(self, terminated):
+        if not terminated:
             return 0
-        else:
-            min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
-            if not self._hit_ball:
-                return 0.2 * (1 - np.tanh(min_r_b_dist**2))
-            else:
-                if self._ball_landing_pos is None:
-                    min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
-                    return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
-                else:
-                    min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
-                    over_net_bonus = int(self._ball_landing_pos[0] < 0)
-                    return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
+        min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
+        if not self._hit_ball:
+            return 0.2 * (1 - np.tanh(min_r_b_dist**2))
+        if self._ball_landing_pos is None:
+            min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
+            return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
+        min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
+        over_net_bonus = int(self._ball_landing_pos[0] < 0)
+        return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus

     def _generate_random_ball(self, random_pos=False, random_vel=False):
         x_pos, y_pos, z_pos = -0.5, 0.35, 1.75

@@ -227,7 +221,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
         init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
-        while not check_init_state_validity(init_ball_state):
+        while not is_init_state_valid(init_ball_state):
             init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state
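The flattened _get_reward keeps three mutually exclusive cases: no hit yields 0.2 * (1 - tanh(min_r_b_dist ** 2)); a hit without a recorded landing adds a goal-distance term; a recorded landing adds a 4x-weighted landing-distance term plus the over-net bonus. Every distance d is squashed through 1 - tanh(d ** 2), which stays in (0, 1] and vanishes quickly for large errors, so each stage is bounded. A quick standalone check of that shaping (not part of the commit):

    import numpy as np

    for d in (0.0, 0.5, 2.0):
        print(d, 1 - np.tanh(d ** 2))  # 1.0, ~0.755, ~0.0007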


@@ -13,7 +13,7 @@ table_y_min = -0.6
 table_y_max = 0.6

 g = 9.81

-def check_init_state_validity(init_state):
+def is_init_state_valid(init_state):
     assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
     x = init_state[0]
     y = init_state[1]
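A sketch of the rejection-sampling loop that pairs with the renamed helper, assuming is_init_state_valid is imported from table_tennis_utils; the jitter and velocity ranges below are invented, only the nominal start position (-0.5, 0.35, 1.75) and the 6D (position, velocity) layout follow the code above:

    import numpy as np

    rng = np.random.default_rng()

    def sample_ball_state():
        # Nominal start pose from _generate_random_ball, jittered; velocities made up.
        pos = np.array([-0.5, 0.35, 1.75]) + rng.uniform(-0.1, 0.1, size=3)
        vel = rng.uniform([1.0, -0.5, -0.5], [2.5, 0.5, 0.5])
        return np.concatenate([pos, vel])

    init_ball_state = sample_ball_state()
    while not is_init_state_valid(init_ball_state):  # resample until valid
        init_ball_state = sample_ball_state()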