updates according to change requests
Commit 344c11d67a (parent 67f684cf14)

.gitignore (vendored)
@@ -111,3 +111,6 @@ venv.bak/
 /configs/db.cfg
 legacy/
 MUJOCO_LOG.TXT
+
+# vscode
+.vscode
@@ -159,52 +159,52 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         infos = dict()
         done = False

-        if traj_is_valid is False:
-            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
+        if not traj_is_valid:
+            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
+                                                                                 self.return_context_observation)
             return self.observation(obs), trajectory_return, done, infos
-        else:
-            self.plan_steps += 1
-            for t, (pos, vel) in enumerate(zip(position, velocity)):
-                step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
-                c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
-                obs, c_reward, done, info = self.env.step(c_action)
-                rewards[t] = c_reward
-
-                if self.verbose >= 2:
-                    actions[t, :] = c_action
-                    observations[t, :] = obs
-
-                for k, v in info.items():
-                    elems = infos.get(k, [None] * trajectory_length)
-                    elems[t] = v
-                    infos[k] = elems
-
-                if self.render_kwargs:
-                    self.env.render(**self.render_kwargs)
-
-                if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
-                                                     t + 1 + self.current_traj_steps)
-                            and self.plan_steps < self.max_planning_times):
-
-                    self.condition_pos = pos if self.condition_on_desired else None
-                    self.condition_vel = vel if self.condition_on_desired else None
-
-                    break
-
-            infos.update({k: v[:t + 1] for k, v in infos.items()})
-            self.current_traj_steps += t + 1
-
-            if self.verbose >= 2:
-                infos['positions'] = position
-                infos['velocities'] = velocity
-                infos['step_actions'] = actions[:t + 1]
-                infos['step_observations'] = observations[:t + 1]
-                infos['step_rewards'] = rewards[:t + 1]
-
-            infos['trajectory_length'] = t + 1
-            trajectory_return = self.reward_aggregation(rewards[:t + 1])
-            return self.observation(obs), trajectory_return, done, infos
+        self.plan_steps += 1
+        for t, (pos, vel) in enumerate(zip(position, velocity)):
+            step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
+            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
+            obs, c_reward, done, info = self.env.step(c_action)
+            rewards[t] = c_reward
+
+            if self.verbose >= 2:
+                actions[t, :] = c_action
+                observations[t, :] = obs
+
+            for k, v in info.items():
+                elems = infos.get(k, [None] * trajectory_length)
+                elems[t] = v
+                infos[k] = elems
+
+            if self.render_kwargs:
+                self.env.render(**self.render_kwargs)
+
+            if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                 t + 1 + self.current_traj_steps)
+                        and self.plan_steps < self.max_planning_times):
+
+                self.condition_pos = pos if self.condition_on_desired else None
+                self.condition_vel = vel if self.condition_on_desired else None
+
+                break
+
+        infos.update({k: v[:t + 1] for k, v in infos.items()})
+        self.current_traj_steps += t + 1
+
+        if self.verbose >= 2:
+            infos['positions'] = position
+            infos['velocities'] = velocity
+            infos['step_actions'] = actions[:t + 1]
+            infos['step_observations'] = observations[:t + 1]
+            infos['step_rewards'] = rewards[:t + 1]
+
+        infos['trajectory_length'] = t + 1
+        trajectory_return = self.reward_aggregation(rewards[:t + 1])
+        return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
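Note for readers: `self.replanning_schedule` above is a plain callable supplied when the wrapper is configured. A minimal sketch of a compatible fixed-interval schedule, assuming the five-argument call signature visible at the call site (the helper name and the interval are illustrative, not part of this commit):

    def make_fixed_interval_schedule(interval: int):
        # Replan whenever the total step count reaches a multiple of `interval`.
        def schedule(current_pos, current_vel, obs, action, step):
            # `step` corresponds to t + 1 + self.current_traj_steps above.
            return step % interval == 0
        return schedule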
@@ -56,12 +56,30 @@ class RawInterfaceWrapper(gym.Wrapper):
             -> Tuple[bool, np.ndarray, np.ndarray]:
         """
         Used to preprocess the action and check if the desired trajectory is valid.
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional
+                parameters if specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            validity flag: bool, True if the raw trajectory is valid, False if not
+            pos_traj: a vector instance of the preprocessed position trajectory
+            vel_traj: a vector instance of the preprocessed velocity trajectory
         """
         return True, pos_traj, vel_traj

     def set_episode_arguments(self, action, pos_traj, vel_traj):
         """
         Used to set the arguments for the env that are valid for the whole episode.
+        Deprecated: replaced by preprocessing_and_validity_callback.
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional
+                parameters if specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            pos_traj: a vector instance of the preprocessed position trajectory
+            vel_traj: a vector instance of the preprocessed velocity trajectory
         """
         return pos_traj, vel_traj
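To make the documented contract concrete: a hedged sketch of a subclass overriding `preprocessing_and_validity_callback` to reject trajectories that leave fixed joint bounds. The class name and the bounds are assumptions for illustration, not part of this commit:

    import numpy as np
    from typing import Tuple

    class MyEnvWrapper(RawInterfaceWrapper):
        # Illustrative joint-position bounds; a real task supplies its own.
        _pos_low, _pos_high = -2.0, 2.0

        def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray,
                                                vel_traj: np.ndarray) -> Tuple[bool, np.ndarray, np.ndarray]:
            # Valid only if the entire desired position trajectory stays within bounds.
            valid = bool(np.all((pos_traj >= self._pos_low) & (pos_traj <= self._pos_high)))
            return valid, pos_traj, vel_traj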
@@ -82,5 +100,15 @@ class RawInterfaceWrapper(gym.Wrapper):
     def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
         """
         Used to return an artificial return from the env if the desired trajectory is invalid.
+        Args:
+            action: a vector instance of the whole action space, includes traj_gen parameters and additional
+                parameters if specified, else only traj_gen parameters
+            pos_traj: a vector instance of the raw position trajectory
+            vel_traj: a vector instance of the raw velocity trajectory
+        Returns:
+            obs: artificial observation if the trajectory is invalid, by default a zero vector
+            reward: artificial reward if the trajectory is invalid, by default 0
+            done: artificial done flag if the trajectory is invalid, by default True
+            info: artificial info dict if the trajectory is invalid, by default an empty dict
         """
         return np.zeros(1), 0, True, {}
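A matching hedged sketch of `invalid_traj_callback`, continuing the illustrative `MyEnvWrapper` above: instead of the default reward of 0, it penalizes by how far the desired trajectory leaves the bounds, giving the learner a shaped signal toward valid parameterizations:

        def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray,
                                  vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
            # Penalty grows with the total out-of-bounds magnitude.
            violation = np.maximum(pos_traj - self._pos_high, 0.) + np.maximum(self._pos_low - pos_traj, 0.)
            reward = -float(np.sum(violation))
            return np.zeros(self.observation_space.shape), reward, True, {'trajectory_valid': False}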
@@ -4,7 +4,7 @@ import numpy as np
 from gym import utils, spaces
 from gym.envs.mujoco import MujocoEnv

-from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
+from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
 from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound

 import mujoco
@@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_return_success = False
         self._ball_landing_pos = None
         self._init_ball_state = None
-        self._episode_end = False
+        self._terminated = False

         self._id_set = False
@@ -54,6 +54,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
             frame_skip=frame_skip,
             mujoco_bindings="mujoco")
+
         if ctxt_dim == 2:
             self.context_bounds = CONTEXT_BOUNDS_2DIMS
         elif ctxt_dim == 4:
@@ -81,7 +82,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):

         unstable_simulation = False

-        if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
+        if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
             new_goal_pos = self._generate_goal_pos(random=True)
             new_goal_pos[1] = -new_goal_pos[1]
             self._goal_pos = new_goal_pos
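(No behavioral change in this hunk: `uniform()` with no arguments defaults to `low=0.0, high=1.0`, so the simplified call draws from the same distribution as `uniform(0, 1)`.)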
@@ -96,7 +97,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             except Exception as e:
                 print("Simulation get unstable return with MujocoException: ", e)
                 unstable_simulation = True
-                self._episode_end = True
+                self._terminated = True
                 break

         if not self._hit_ball:
@@ -106,33 +107,32 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
             if ball_land_on_floor_no_hit:
                 self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
-                self._episode_end = True
-        if self._hit_ball:
-            if not self._ball_contact_after_hit:
-                if self._contact_checker(self._ball_contact_id, self._floor_contact_id):  # first check contact with floor
-                    self._ball_contact_after_hit = True
-                    self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
-                    self._episode_end = True
-                elif self._contact_checker(self._ball_contact_id, self._table_contact_id):  # second check contact with table
-                    self._ball_contact_after_hit = True
-                    self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
-                    if self._ball_landing_pos[0] < 0.:  # ball lands on the opponent side
-                        self._ball_return_success = True
-                    self._episode_end = True
+                self._terminated = True
+        if self._hit_ball and not self._ball_contact_after_hit:
+            if self._contact_checker(self._ball_contact_id, self._floor_contact_id):  # first check contact with floor
+                self._ball_contact_after_hit = True
+                self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
+                self._terminated = True
+            elif self._contact_checker(self._ball_contact_id, self._table_contact_id):  # second check contact with table
+                self._ball_contact_after_hit = True
+                self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
+                if self._ball_landing_pos[0] < 0.:  # ball lands on the opponent side
+                    self._ball_return_success = True
+                self._terminated = True

         # update ball trajectory & racket trajectory
         self._ball_traj.append(self.data.body("target_ball").xpos.copy())
         self._racket_traj.append(self.data.geom("bat").xpos.copy())

         self._steps += 1
-        self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end
+        self._terminated = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._terminated

-        reward = -25 if unstable_simulation else self._get_reward(self._episode_end)
+        reward = -25 if unstable_simulation else self._get_reward(self._terminated)

         land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
             if self._ball_landing_pos is not None else 10.

-        return self._get_obs(), reward, self._episode_end, {
+        return self._get_obs(), reward, self._terminated, {
             "hit_ball": self._hit_ball,
             "ball_returned_success": self._ball_return_success,
             "land_dist_error": land_dist_err,
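For readers unfamiliar with the `self._contact_checker` calls above (its definition is not shown in this diff): a hedged sketch of what a pairwise geom-contact test over MuJoCo's contact buffer typically looks like, given a `mujoco.MjData` instance; this standalone function is illustrative and may differ from the actual implementation:

    def contact_between(data, geom_id_a: int, geom_id_b: int) -> bool:
        # Scan the contacts recorded by the most recent physics step.
        for i in range(data.ncon):
            contact = data.contact[i]
            if {contact.geom1, contact.geom2} == {geom_id_a, geom_id_b}:
                return True
        return False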
@@ -173,7 +173,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._ball_contact_after_hit = False
         self._ball_return_success = False
         self._ball_landing_pos = None
-        self._episode_end = False
+        self._terminated = False
         self._ball_traj = []
         self._racket_traj = []
         return self._get_obs()
@@ -195,24 +195,18 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         ])
         return obs

-    def get_obs(self):
-        return self._get_obs()
-
-    def _get_reward(self, episode_end):
-        if not episode_end:
+    def _get_reward(self, terminated):
+        if not terminated:
             return 0
-        else:
-            min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
-            if not self._hit_ball:
-                return 0.2 * (1 - np.tanh(min_r_b_dist**2))
-            else:
-                if self._ball_landing_pos is None:
-                    min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1))
-                    return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
-                else:
-                    min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
-                    over_net_bonus = int(self._ball_landing_pos[0] < 0)
-                    return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
+        min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
+        if not self._hit_ball:
+            return 0.2 * (1 - np.tanh(min_r_b_dist**2))
+        if self._ball_landing_pos is None:
+            min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:, :2] - self._goal_pos[:2], axis=1))
+            return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
+        min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
+        over_net_bonus = int(self._ball_landing_pos[0] < 0)
+        return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus

     def _generate_random_ball(self, random_pos=False, random_vel=False):
         x_pos, y_pos, z_pos = -0.5, 0.35, 1.75
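The flattened `_get_reward` keeps the same three-stage shaping: a small distance-closing term before any hit, an added goal-proximity term after a hit with no recorded landing, and a landing-accuracy term (weight 4) plus an over-net bonus once the ball lands. A self-contained sketch evaluating those formulas on synthetic data; every array value below is invented for illustration:

    import numpy as np

    # Synthetic (T x 3) ball and racket trajectories and a 2-D goal position.
    ball = np.array([[0.0, 0.3, 1.5], [0.1, 0.2, 1.2], [0.2, 0.1, 0.9]])
    racket = np.array([[0.0, 0.0, 1.0], [0.1, 0.1, 1.0], [0.2, 0.2, 1.0]])
    goal = np.array([-1.0, 0.5])

    min_r_b_dist = np.min(np.linalg.norm(ball - racket, axis=1))
    # Stage 1 (no hit): only closing the racket-ball distance is rewarded.
    r_no_hit = 0.2 * (1 - np.tanh(min_r_b_dist ** 2))
    # Stage 2 (hit, no landing recorded): add closeness of the ball path to the goal.
    min_b_des_b_dist = np.min(np.linalg.norm(ball[:, :2] - goal, axis=1))
    r_hit = 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist ** 2))
    # Stage 3 (hit and landed): landing accuracy weighted by 4, plus a bonus
    # for landing on the opponent side (x < 0).
    landing = np.array([-0.8, 0.4, 0.0])
    land_dist = np.linalg.norm(goal - landing[:2])
    r_landed = 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(land_dist ** 2)) + int(landing[0] < 0)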
@@ -227,7 +221,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):

     def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
         init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
-        while not check_init_state_validity(init_ball_state):
+        while not is_init_state_valid(init_ball_state):
             init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state
@@ -13,7 +13,7 @@ table_y_min = -0.6
 table_y_max = 0.6
 g = 9.81

-def check_init_state_validity(init_state):
+def is_init_state_valid(init_state):
     assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
     x = init_state[0]
     y = init_state[1]