updated table tennis and beerpong for ProMP usage
parent 083e937e17
commit a0af743585
@@ -198,14 +198,19 @@ register(
 
 ## Table Tennis
 register(id='TableTennis2DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':2})
+         kwargs={'ctxt_dim': 2})
 
+register(id='TableTennis2DCtxt-v1',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
+         max_episode_steps=1750,
+         kwargs={'ctxt_dim': 2, 'fixed_goal': True})
+
 register(id='TableTennis4DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':4})
+         kwargs={'ctxt_dim': 4})
 
 ## BeerPong
 difficulties = ["simple", "intermediate", "hard", "hardest"]
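Note: a quick usage sketch of the environments registered above (assuming `alr_envs` is installed and importing it executes these `register()` calls):

import gym
import alr_envs  # noqa: F401  -- importing the package runs the register() calls

# the new fixed-goal variant registered in this commit
env = gym.make('TableTennis2DCtxt-v1')
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
print(info)  # contains "hit_ball", see the step() return further below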
@@ -369,13 +374,10 @@ register(
         "mp_kwargs": {
             "num_dof": 7,
             "num_basis": 2,
-            "n_zero_bases": 2,
-            "duration": 0.5,
-            "post_traj_time": 2.5,
-            # "width": 0.01,
-            # "off": 0.01,
+            "duration": 1,
+            "post_traj_time": 2,
             "policy_type": "motor",
-            "weights_scale": 0.08,
+            "weights_scale": 0.2,
             "zero_start": True,
             "zero_goal": False,
             "policy_kwargs": {
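Note: with `num_dof=7` and `num_basis=2`, the BeerPong ProMP is parameterized by 14 weights, and `weights_scale` rises from 0.08 to 0.2. A small numeric sketch of what the change means (the weight vector itself is illustrative, not from the commit):

import numpy as np

weights = np.random.randn(7 * 2)   # 7 DoF x 2 basis functions
old_params = 0.08 * weights        # scaling before this commit
new_params = 0.2 * weights         # scaling after: larger trajectory amplitudes
print(new_params / old_params)     # uniformly 2.5x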
@@ -388,20 +390,20 @@ register(
 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("BeerpongProMP-v0")
 
 ## Table Tennis
-register(
-    id='TableTennisProMP-v0',
+ctxt_dim = [2, 4]
+for _v, cd in enumerate(ctxt_dim):
+    _env_id = f'TableTennisProMP-v{_v}'
+    register(
+        id=_env_id,
         entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
         kwargs={
-            "name": "alr_envs:TableTennis4DCtxt-v0",
+            "name": "alr_envs:TableTennis{}DCtxt-v0".format(cd),
             "wrappers": [mujoco.table_tennis.MPWrapper],
             "mp_kwargs": {
                 "num_dof": 7,
                 "num_basis": 2,
-                "n_zero_bases": 2,
                 "duration": 1.25,
                 "post_traj_time": 4.5,
-                # "width": 0.01,
-                # "off": 0.01,
                 "policy_type": "motor",
                 "weights_scale": 1.0,
                 "zero_start": True,
@@ -412,5 +414,29 @@ register(
             }
         }
     }
+    )
+    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
+register(
+    id='TableTennisProMP-v2',
+    entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
+    kwargs={
+        "name": "alr_envs:TableTennis2DCtxt-v1",
+        "wrappers": [mujoco.table_tennis.MPWrapper],
+        "mp_kwargs": {
+            "num_dof": 7,
+            "num_basis": 2,
+            "duration": 1.,
+            "post_traj_time": 2.5,
+            "policy_type": "motor",
+            "weights_scale": 0.2,
+            "zero_start": True,
+            "zero_goal": False,
+            "policy_kwargs": {
+                "p_gains": 0.5*np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
+                "d_gains": 0.5*np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
+            }
+        }
+    }
 )
-ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v0")
+ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v2")
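Note: after this hunk the loop registers `TableTennisProMP-v0` (2D context) and `TableTennisProMP-v1` (4D context), while the fixed-goal variant is registered separately as `TableTennisProMP-v2`. A sanity-check sketch (assuming this module is `alr_envs.alr` and the lookup dict is importable from it):

from alr_envs.alr import ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS

promp_ids = ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"]
assert {'TableTennisProMP-v0', 'TableTennisProMP-v1', 'TableTennisProMP-v2'} <= set(promp_ids)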
@@ -14,8 +14,5 @@
 |`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25
 |`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25
 |`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30
-|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupSimple-v0` task where only 3 joints are actuated. | 4000 | 15
-|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
-|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
 
 [//]: |`HoleReacherProMPP-v0`|
@@ -2,5 +2,5 @@ from .reacher.alr_reacher import ALRReacherEnv
 from .reacher.balancing import BalancingEnv
 from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
 from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
-from .table_tennis.tt_gym import TT_Env_Gym
+from .table_tennis.tt_gym import TTEnvGym
 from .beerpong.beerpong import ALRBeerBongEnv
@@ -27,10 +27,10 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         self.ball_site_id = 0
         self.ball_id = 11
 
-        self._release_step = 100  # time step of ball release
+        self._release_step = 175  # time step of ball release
 
-        self.sim_time = 4  # seconds
-        self.ep_length = 600  # based on 5 seconds with dt = 0.005 int(self.sim_time / self.dt)
+        self.sim_time = 3  # seconds
+        self.ep_length = 600  # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt)
         self.cup_table_id = 10
 
         if noisy:
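Note: the updated comment now matches the arithmetic. A worked check (assuming `_release_step` counts the same 0.005 s simulation steps as `ep_length`):

dt = 0.005
print(int(3 / dt))   # 600 -> ep_length for the new 3 s horizon
print(175 * dt)      # 0.875 s: new ball-release time (was 100 * dt = 0.5 s)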
@@ -143,7 +143,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
                     q_vel=self.sim.data.qvel[0:7].ravel().copy(),
                     ball_pos=ball_pos,
                     ball_vel=ball_vel,
-                    is_success=success,
+                    success=success,
                     is_collided=is_collided, sim_crash=crash)
 
     def check_traj_in_joint_limits(self):
@@ -171,7 +171,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
 
 
 if __name__ == "__main__":
-    env = ALRBeerBongEnv(reward_type="no_context", difficulty='hardest')
+    env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
 
     # env.configure(ctxt)
     env.reset()
@@ -71,6 +71,7 @@ class BeerPongReward:
 
         goal_pos = env.sim.data.site_xpos[self.goal_id]
         ball_pos = env.sim.data.body_xpos[self.ball_id]
+        ball_vel = env.sim.data.body_xvelp[self.ball_id]
         goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
         self.dists.append(np.linalg.norm(goal_pos - ball_pos))
         self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
@@ -131,6 +132,7 @@ class BeerPongReward:
         infos["success"] = success
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
+        infos["ball_vel"] = ball_vel.copy()
         infos["action_cost"] = 5e-4 * action_cost
 
         return reward, infos
@@ -81,32 +81,36 @@ class BeerPongReward:
         action_cost = np.sum(np.square(action))
         self.action_costs.append(action_cost)
 
+        if not self.ball_table_contact:
+            self.ball_table_contact = self._check_collision_single_objects(env.sim, self.ball_collision_id,
+                                                                           self.table_collision_id)
+
         self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids)
         if env._steps == env.ep_length - 1 or self._is_collided:
 
             min_dist = np.min(self.dists)
-            ball_table_bounce = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                     self.table_collision_id)
-            ball_cup_table_cont = self._check_collision_with_set_of_objects(env.sim, self.ball_collision_id,
-                                                                            self.cup_collision_ids)
-            ball_wall_cont = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                  self.wall_collision_id)
+            final_dist = self.dists_final[-1]
             ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id,
                                                                self.cup_table_collision_id)
-            if not ball_in_cup:
-                cost_offset = 2
-                if not ball_cup_table_cont and not ball_table_bounce and not ball_wall_cont:
-                    cost_offset += 2
-                cost = cost_offset + min_dist ** 2 + 0.5 * self.dists_final[-1] ** 2 + 1e-7 * action_cost
-            else:
-                cost = self.dists_final[-1] ** 2 + 1.5 * action_cost * 1e-7
 
-            reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
+            # encourage bounce before falling into cup
+            if not ball_in_cup:
+                if not self.ball_table_contact:
+                    reward = 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
+                else:
+                    reward = (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
+            else:
+                if not self.ball_table_contact:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 1
+                else:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 3
+
+            # reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
             success = ball_in_cup
             crash = self._is_collided
         else:
-            reward = - 1e-7 * action_cost
-            cost = 0
+            reward = - 1e-4 * action_cost
             success = False
             crash = False
 
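Note: the rewrite replaces the unbounded quadratic cost with tanh-squashed distance terms, staged by contact events; bouncing off the table before landing in the cup earns the largest constant (+3). A self-contained sketch of just that branch logic, with made-up distances for illustration:

import numpy as np

def staged_reward(min_dist, final_dist, ball_in_cup, ball_table_contact):
    # mirrors the branches introduced above
    if not ball_in_cup:
        if not ball_table_contact:
            return 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
        return (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
    if not ball_table_contact:
        return 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 1
    return 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 3

print(staged_reward(0.5, 0.9, False, False))  # clean miss, no bounce: small shaping signal
print(staged_reward(0.0, 0.0, True, True))    # bounce, then cup: 2 + 1 + 3 = 6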
@@ -115,26 +119,11 @@ class BeerPongReward:
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
         infos["ball_vel"] = ball_vel.copy()
-        infos["action_cost"] = 5e-4 * action_cost
-        infos["task_cost"] = cost
+        infos["action_cost"] = action_cost
+        infos["task_reward"] = reward
 
         return reward, infos
 
-    def get_cost_offset(self):
-        if self.ball_ground_contact:
-            return 200
-
-        if not self.ball_table_contact:
-            return 100
-
-        if not self.ball_in_cup:
-            return 50
-
-        if self.ball_in_cup and self.ball_cup_contact and not self.noisy_bp:
-            return 10
-
-        return 0
-
     def _check_collision_single_objects(self, sim, id_1, id_2):
         for coni in range(0, sim.data.ncon):
             con = sim.data.contact[coni]
@@ -6,8 +6,6 @@ from gym.envs.mujoco import MujocoEnv
 
 class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
     def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None):
-        utils.EzPickle.__init__(**locals())
-
         self._steps = 0
 
         self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -28,15 +26,13 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
 
         self.context = None
 
-        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
-
         # alr_mujoco_env.AlrMujocoEnv.__init__(self,
         #                                      self.xml_path,
         #                                      apply_gravity_comp=apply_gravity_comp,
         #                                      n_substeps=n_substeps)
 
         self.sim_time = 8  # seconds
-        self.sim_steps = int(self.sim_time / self.dt)
+        # self.sim_steps = int(self.sim_time / self.dt)
         if reward_function is None:
             from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward
             reward_function = BeerpongReward
@@ -46,6 +42,9 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         self.cup_table_id = self.sim.model._body_name2id["cup_table"]
         # self.bounce_table_id = self.sim.model._body_name2id["bounce_table"]
 
+        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
+        utils.EzPickle.__init__(self)
+
     @property
     def current_pos(self):
         return self.sim.data.qpos[0:7].copy()
@@ -90,7 +89,7 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         reward_ctrl = - np.square(a).sum()
         action_cost = np.sum(np.square(a))
 
-        crash = self.do_simulation(a)
+        crash = self.do_simulation(a, self.frame_skip)
         joint_cons_viol = self.check_traj_in_joint_limits()
 
         self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy())
@@ -10,7 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
 
 #TODO: Check for simulation stability. Make sure the code runs even for sim crash
 
-MAX_EPISODE_STEPS = 1375
+MAX_EPISODE_STEPS = 2875
 BALL_NAME_CONTACT = "target_ball_contact"
 BALL_NAME = "target_ball"
 TABLE_NAME = 'table_tennis_table'
@@ -22,14 +22,19 @@ RACKET_NAME = 'bat'
 CONTEXT_RANGE_BOUNDS_2DIM = np.array([[-1.2, -0.6], [-0.2, 0.0]])
 CONTEXT_RANGE_BOUNDS_4DIM = np.array([[-1.35, -0.75, -1.25, -0.75], [-0.1, 0.75, -0.1, 0.75]])
 
-class TT_Env_Gym(MujocoEnv, utils.EzPickle):
+
+class TTEnvGym(MujocoEnv, utils.EzPickle):
 
-    def __init__(self, ctxt_dim=2):
+    def __init__(self, ctxt_dim=2, fixed_goal=False):
         model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml')
 
         self.ctxt_dim = ctxt_dim
+        self.fixed_goal = fixed_goal
         if ctxt_dim == 2:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM
-            self.goal = np.zeros(3)  # 2 x,y + 1z
+            if self.fixed_goal:
+                self.goal = np.array([-1, -0.1, 0])
+            else:
+                self.goal = np.zeros(3)  # 2 x,y + 1z
         elif ctxt_dim == 4:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM
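Note: a minimal construction sketch for the new flag, using the class directly instead of the gym registry (assumes MuJoCo and the bundled XML are available):

from alr_envs.alr.mujoco import TTEnvGym

env = TTEnvGym(ctxt_dim=2, fixed_goal=True)  # goal pinned to [-1, -0.1, 0]
obs = env.reset()                            # reset_model() keeps the fixed goal (see below)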
@@ -47,10 +52,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
 
         self.reward_func = TT_Reward(self.ctxt_dim)
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
         self._ids_set = False
-        super(TT_Env_Gym, self).__init__(model_path=model_path, frame_skip=1)
+        super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
         self.ball_id = self.sim.model._body_name2id[BALL_NAME]  # find the proper -> not protected func.
         self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
         self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
@@ -77,14 +82,17 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         return obs
 
     def sample_context(self):
-        return np.random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
+        return self.np_random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
 
     def reset_model(self):
         self.set_state(self.init_qpos_tt, self.init_qvel_tt)  # reset to initial sim state
         self.time_steps = 0
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
-        self.goal = self.sample_context()[:2]
+        if self.fixed_goal:
+            self.goal = self.goal[:2]
+        else:
+            self.goal = self.sample_context()[:2]
         if self.ctxt_dim == 2:
             initial_ball_state = ball_init(random=False)  # fixed velocity, fixed position
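Note: replacing the global `np.random` with `self.np_random` routes context sampling through gym's seeding machinery, so sampled goals become reproducible per seed. A hedged sketch (assumes this gym version's `env.seed()` seeds `self.np_random` as usual for `MujocoEnv` subclasses):

from alr_envs.alr.mujoco import TTEnvGym

env = TTEnvGym(ctxt_dim=2)
env.seed(42)
env.reset()
first_goal = env.goal.copy()
env.seed(42)
env.reset()
assert (env.goal == first_goal).all()  # same seed -> same sampled context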
@@ -122,12 +130,12 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         if not self._ids_set:
             self._set_ids()
         done = False
-        episode_end = False if self.time_steps+1<MAX_EPISODE_STEPS else True
-        if not self.hited_ball:
-            self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1)  # check for one side
-            if not self.hited_ball:
-                self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2)  # check for other side
-        if self.hited_ball:
+        episode_end = False if self.time_steps + 1 < MAX_EPISODE_STEPS else True
+        if not self.hit_ball:
+            self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1)  # check for one side
+            if not self.hit_ball:
+                self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2)  # check for other side
+        if self.hit_ball:
             if not self.ball_contact_after_hit:
                 if self._contact_checker(self.ball_contact_id, self.floor_contact_id):  # first check contact with floor
                     self.ball_contact_after_hit = True
@@ -140,7 +148,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
             if self.ball_landing_pos is not None:
                 done = True
                 episode_end = True
-        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hited_ball, self.ball_landing_pos)
+        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hit_ball, self.ball_landing_pos)
         self.time_steps += 1
         # gravity compensation on joints:
         #action += self.sim.data.qfrc_bias[:7].copy()
@@ -151,7 +159,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
             done = True
             reward = -25
         ob = self._get_obs()
-        return ob, reward, done, {"hit_ball":self.hited_ball}# might add some information here ....
+        return ob, reward, done, {"hit_ball": self.hit_ball}  # might add some information here ....
 
     def set_context(self, context):
         old_state = self.sim.get_state()