updates according to changes request
This commit is contained in:
parent
67f684cf14
commit
344c11d67a
3
.gitignore
vendored
3
.gitignore
vendored
@ -111,3 +111,6 @@ venv.bak/
|
|||||||
/configs/db.cfg
|
/configs/db.cfg
|
||||||
legacy/
|
legacy/
|
||||||
MUJOCO_LOG.TXT
|
MUJOCO_LOG.TXT
|
||||||
|
|
||||||
|
# vscode
|
||||||
|
.vscode
|
||||||
|
@ -159,11 +159,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
if traj_is_valid is False:
|
if not traj_is_valid:
|
||||||
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
|
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||||
self.return_context_observation)
|
self.return_context_observation)
|
||||||
return self.observation(obs), trajectory_return, done, infos
|
return self.observation(obs), trajectory_return, done, infos
|
||||||
else:
|
|
||||||
self.plan_steps += 1
|
self.plan_steps += 1
|
||||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||||
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
||||||
|
@ -56,12 +56,30 @@ class RawInterfaceWrapper(gym.Wrapper):
|
|||||||
-> Tuple[bool, np.ndarray, np.ndarray]:
|
-> Tuple[bool, np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Used to preprocess the action and check if the desired trajectory is valid.
|
Used to preprocess the action and check if the desired trajectory is valid.
|
||||||
|
Args:
|
||||||
|
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
|
||||||
|
specified, else only traj_gen parameters
|
||||||
|
pos_traj: a vector instance of the raw position trajectory
|
||||||
|
vel_traj: a vector instance of the raw velocity trajectory
|
||||||
|
Returns:
|
||||||
|
validity flag: bool, True if the raw trajectory is valid, False if not
|
||||||
|
pos_traj: a vector instance of the preprocessed position trajectory
|
||||||
|
vel_traj: a vector instance of the preprocessed velocity trajectory
|
||||||
"""
|
"""
|
||||||
return True, pos_traj, vel_traj
|
return True, pos_traj, vel_traj
|
||||||
|
|
||||||
def set_episode_arguments(self, action, pos_traj, vel_traj):
|
def set_episode_arguments(self, action, pos_traj, vel_traj):
|
||||||
"""
|
"""
|
||||||
Used to set the arguments for the env that are valid for the whole episode
|
Used to set the arguments for the env that are valid for the whole episode
|
||||||
|
deprecated, replaced by preprocessing_and_validity_callback
|
||||||
|
Args:
|
||||||
|
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
|
||||||
|
specified, else only traj_gen parameters
|
||||||
|
pos_traj: a vector instance of the raw position trajectory
|
||||||
|
vel_traj: a vector instance of the raw velocity trajectory
|
||||||
|
Returns:
|
||||||
|
pos_traj: a vector instance of the preprocessed position trajectory
|
||||||
|
vel_traj: a vector instance of the preprocessed velocity trajectory
|
||||||
"""
|
"""
|
||||||
return pos_traj, vel_traj
|
return pos_traj, vel_traj
|
||||||
|
|
||||||
@ -82,5 +100,15 @@ class RawInterfaceWrapper(gym.Wrapper):
|
|||||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
|
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
|
||||||
"""
|
"""
|
||||||
Used to return an artificial return from the env if the desired trajectory is invalid.
|
Used to return an artificial return from the env if the desired trajectory is invalid.
|
||||||
|
Args:
|
||||||
|
action: a vector instance of the whole action space, includes traj_gen parameters and additional parameters if
|
||||||
|
specified, else only traj_gen parameters
|
||||||
|
pos_traj: a vector instance of the raw position trajectory
|
||||||
|
vel_traj: a vector instance of the raw velocity trajectory
|
||||||
|
Returns:
|
||||||
|
obs: artificial observation if the trajectory is invalid, by default a zero vector
|
||||||
|
reward: artificial reward if the trajectory is invalid, by default 0
|
||||||
|
done: artificial done if the trajectory is invalid, by default True
|
||||||
|
info: artificial info if the trajectory is invalid, by default empty dict
|
||||||
"""
|
"""
|
||||||
return np.zeros(1), 0, True, {}
|
return np.zeros(1), 0, True, {}
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from gym import utils, spaces
|
from gym import utils, spaces
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gym.envs.mujoco import MujocoEnv
|
||||||
|
|
||||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
|
||||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
||||||
|
|
||||||
import mujoco
|
import mujoco
|
||||||
@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self._ball_return_success = False
|
self._ball_return_success = False
|
||||||
self._ball_landing_pos = None
|
self._ball_landing_pos = None
|
||||||
self._init_ball_state = None
|
self._init_ball_state = None
|
||||||
self._episode_end = False
|
self._terminated = False
|
||||||
|
|
||||||
self._id_set = False
|
self._id_set = False
|
||||||
|
|
||||||
@ -54,6 +54,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
mujoco_bindings="mujoco")
|
mujoco_bindings="mujoco")
|
||||||
|
|
||||||
if ctxt_dim == 2:
|
if ctxt_dim == 2:
|
||||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
elif ctxt_dim == 4:
|
elif ctxt_dim == 4:
|
||||||
@ -81,7 +82,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
|
if self._steps == self._goal_switching_step and self.np_random.uniform() < 0.5:
|
||||||
new_goal_pos = self._generate_goal_pos(random=True)
|
new_goal_pos = self._generate_goal_pos(random=True)
|
||||||
new_goal_pos[1] = -new_goal_pos[1]
|
new_goal_pos[1] = -new_goal_pos[1]
|
||||||
self._goal_pos = new_goal_pos
|
self._goal_pos = new_goal_pos
|
||||||
@ -96,7 +97,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Simulation get unstable return with MujocoException: ", e)
|
print("Simulation get unstable return with MujocoException: ", e)
|
||||||
unstable_simulation = True
|
unstable_simulation = True
|
||||||
self._episode_end = True
|
self._terminated = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self._hit_ball:
|
if not self._hit_ball:
|
||||||
@ -106,33 +107,32 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
ball_land_on_floor_no_hit = self._contact_checker(self._ball_contact_id, self._floor_contact_id)
|
||||||
if ball_land_on_floor_no_hit:
|
if ball_land_on_floor_no_hit:
|
||||||
self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
self._ball_landing_pos = self.data.body("target_ball").xpos.copy()
|
||||||
self._episode_end = True
|
self._terminated = True
|
||||||
if self._hit_ball and not self._ball_contact_after_hit:
|
if self._hit_ball and not self._ball_contact_after_hit:
|
||||||
if not self._ball_contact_after_hit:
|
|
||||||
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
if self._contact_checker(self._ball_contact_id, self._floor_contact_id): # first check contact with floor
|
||||||
self._ball_contact_after_hit = True
|
self._ball_contact_after_hit = True
|
||||||
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
||||||
self._episode_end = True
|
self._terminated = True
|
||||||
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
elif self._contact_checker(self._ball_contact_id, self._table_contact_id): # second check contact with table
|
||||||
self._ball_contact_after_hit = True
|
self._ball_contact_after_hit = True
|
||||||
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
self._ball_landing_pos = self.data.geom("target_ball_contact").xpos.copy()
|
||||||
if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
if self._ball_landing_pos[0] < 0.: # ball lands on the opponent side
|
||||||
self._ball_return_success = True
|
self._ball_return_success = True
|
||||||
self._episode_end = True
|
self._terminated = True
|
||||||
|
|
||||||
# update ball trajectory & racket trajectory
|
# update ball trajectory & racket trajectory
|
||||||
self._ball_traj.append(self.data.body("target_ball").xpos.copy())
|
self._ball_traj.append(self.data.body("target_ball").xpos.copy())
|
||||||
self._racket_traj.append(self.data.geom("bat").xpos.copy())
|
self._racket_traj.append(self.data.geom("bat").xpos.copy())
|
||||||
|
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
self._episode_end = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._episode_end
|
self._terminated = True if self._steps >= MAX_EPISODE_STEPS_TABLE_TENNIS else self._terminated
|
||||||
|
|
||||||
reward = -25 if unstable_simulation else self._get_reward(self._episode_end)
|
reward = -25 if unstable_simulation else self._get_reward(self._terminated)
|
||||||
|
|
||||||
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||||
if self._ball_landing_pos is not None else 10.
|
if self._ball_landing_pos is not None else 10.
|
||||||
|
|
||||||
return self._get_obs(), reward, self._episode_end, {
|
return self._get_obs(), reward, self._terminated, {
|
||||||
"hit_ball": self._hit_ball,
|
"hit_ball": self._hit_ball,
|
||||||
"ball_returned_success": self._ball_return_success,
|
"ball_returned_success": self._ball_return_success,
|
||||||
"land_dist_error": land_dist_err,
|
"land_dist_error": land_dist_err,
|
||||||
@ -173,7 +173,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self._ball_contact_after_hit = False
|
self._ball_contact_after_hit = False
|
||||||
self._ball_return_success = False
|
self._ball_return_success = False
|
||||||
self._ball_landing_pos = None
|
self._ball_landing_pos = None
|
||||||
self._episode_end = False
|
self._terminated = False
|
||||||
self._ball_traj = []
|
self._ball_traj = []
|
||||||
self._racket_traj = []
|
self._racket_traj = []
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
@ -195,21 +195,15 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
])
|
])
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def get_obs(self):
|
def _get_reward(self, terminated):
|
||||||
return self._get_obs()
|
if not terminated:
|
||||||
|
|
||||||
def _get_reward(self, episode_end):
|
|
||||||
if not episode_end:
|
|
||||||
return 0
|
return 0
|
||||||
else:
|
|
||||||
min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
|
min_r_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj) - np.array(self._racket_traj), axis=1))
|
||||||
if not self._hit_ball:
|
if not self._hit_ball:
|
||||||
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
return 0.2 * (1 - np.tanh(min_r_b_dist**2))
|
||||||
else:
|
|
||||||
if self._ball_landing_pos is None:
|
if self._ball_landing_pos is None:
|
||||||
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
|
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self._ball_traj)[:,:2] - self._goal_pos[:2], axis=1))
|
||||||
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + (1 - np.tanh(min_b_des_b_dist**2))
|
||||||
else:
|
|
||||||
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
min_b_des_b_land_dist = np.linalg.norm(self._goal_pos[:2] - self._ball_landing_pos[:2])
|
||||||
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
over_net_bonus = int(self._ball_landing_pos[0] < 0)
|
||||||
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
|
return 2 * (1 - np.tanh(min_r_b_dist ** 2)) + 4 * (1 - np.tanh(min_b_des_b_land_dist ** 2)) + over_net_bonus
|
||||||
@ -227,7 +221,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
|
def _generate_valid_init_ball(self, random_pos=False, random_vel=False):
|
||||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||||
while not check_init_state_validity(init_ball_state):
|
while not is_init_state_valid(init_ball_state):
|
||||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||||
return init_ball_state
|
return init_ball_state
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ table_y_min = -0.6
|
|||||||
table_y_max = 0.6
|
table_y_max = 0.6
|
||||||
g = 9.81
|
g = 9.81
|
||||||
|
|
||||||
def check_init_state_validity(init_state):
|
def is_init_state_valid(init_state):
|
||||||
assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
|
assert len(init_state) == 6, "init_state must be a 6D vector (pos+vel),got {}".format(init_state)
|
||||||
x = init_state[0]
|
x = init_state[0]
|
||||||
y = init_state[1]
|
y = init_state[1]
|
||||||
|
Loading…
Reference in New Issue
Block a user