fancy_gym/alr_envs/alr/mujoco/table_tennis/tt_gym.py
2021-12-06 13:43:45 +01:00

177 lines
8.3 KiB
Python

import os
import numpy as np
import mujoco_py
from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv
from alr_envs.alr.mujoco.table_tennis.tt_utils import ball_init
from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
# TODO: Check simulation stability and make sure the code keeps running even if the simulation crashes.
# Step count at which an episode is flagged as ended for the reward function.
MAX_EPISODE_STEPS = 2875

# Body name of the ball; resolved to a body id in the MuJoCo model.
BALL_NAME = "target_ball"

# Geom names resolved to geom ids in the MuJoCo model (used for contact checks).
BALL_NAME_CONTACT = "target_ball_contact"
TABLE_NAME = "table_tennis_table"
FLOOR_NAME = "floor"
PADDLE_CONTACT_1_NAME = "bat"
PADDLE_CONTACT_2_NAME = "bat_back"
RACKET_NAME = "bat"

# Context sampling bounds: row 0 holds the lower limits, row 1 the upper limits.
# An earlier, wider 2-D variant used [[-1.2, -0.6], [-0.2, 0.6]].
CONTEXT_RANGE_BOUNDS_2DIM = np.array([
    [-1.2, -0.6],
    [-0.2, 0.0],
])
CONTEXT_RANGE_BOUNDS_4DIM = np.array([
    [-1.35, -0.75, -1.25, -0.75],
    [-0.1, 0.75, -0.1, 0.75],
])
class TTEnvGym(MujocoEnv, utils.EzPickle):
    """MuJoCo table-tennis environment.

    Each step the environment checks the simulator's contact list to find out
    whether the paddle has hit the ball and, once it has, where the ball first
    touches the floor or the table. That first contact point is recorded as
    the landing position and ends the episode.

    Observation: the first 7 entries of ``qpos``, followed by the ball body
    position (3 values), followed by the current goal vector.

    ``done`` is only set when a landing position has been recorded or when the
    simulation raises a ``MujocoException``. Reaching ``MAX_EPISODE_STEPS`` is
    reported to the reward function but does not set ``done`` by itself.
    """

    # z-coordinate appended to every planar goal.
    # NOTE(review): presumably the height of the table surface -- confirm against the XML.
    _GOAL_Z = 0.77

    # Joint names used to place the ball and set its velocity on reset.
    _BALL_JOINTS = ('tar:x', 'tar:y', 'tar:z')

    def __init__(self, ctxt_dim=2, fixed_goal=False):
        """Create the environment.

        Args:
            ctxt_dim: Dimensionality of the sampled context, either 2 or 4.
            fixed_goal: If True the goal is not re-sampled on reset. Only the
                2-dimensional setup defines a non-zero fixed goal.

        Raises:
            ValueError: If ``ctxt_dim`` is neither 2 nor 4.
        """
        model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml')

        self.ctxt_dim = ctxt_dim
        self.fixed_goal = fixed_goal

        if ctxt_dim == 2:
            self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM
            if self.fixed_goal:
                self.goal = np.array([-1, -0.1, 0])
            else:
                self.goal = np.zeros(3)  # 2 x,y + 1z
        elif ctxt_dim == 4:
            self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM
            self.goal = np.zeros(3)
        else:
            raise ValueError("either 2 or 4 dimensional Contexts available")

        # NOTE(review): gym's mujoco-py base env derives its own action space
        # from the model's actuator control range during __init__, which would
        # replace this one -- confirm which space is actually in effect.
        action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
        action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
        self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')

        self.time_steps = 0
        self.init_qpos_tt = np.array([0, 0, 0, 1.5, 0, 0, 1.5, 0, 0, 0])
        self.init_qvel_tt = np.zeros(10)

        self.reward_func = TT_Reward(self.ctxt_dim)
        self.ball_landing_pos = None
        self.hit_ball = False
        self.ball_contact_after_hit = False
        self._ids_set = False

        # Everything step() reads must exist before this call: step() resolves
        # the model ids lazily through ``_ids_set`` because it may be invoked
        # by the base-class constructor (presumably to infer the observation
        # space -- confirm for the installed gym version).
        super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)

        # Record the constructor arguments so that pickling / copying can
        # re-create an equivalent environment. Done after the base-class
        # constructor so the recorded arguments cannot be overwritten by it.
        utils.EzPickle.__init__(self, ctxt_dim=ctxt_dim, fixed_goal=fixed_goal)

        self._set_ids()

    def _set_ids(self):
        """Resolve and cache the body/geom ids that step() and _get_obs() use."""
        model = self.sim.model
        self.ball_id = model.body_name2id(BALL_NAME)
        self.ball_contact_id = model.geom_name2id(BALL_NAME_CONTACT)
        self.table_contact_id = model.geom_name2id(TABLE_NAME)
        self.floor_contact_id = model.geom_name2id(FLOOR_NAME)
        # Both paddle sides are checked for a hit in step().
        self.paddle_contact_id_1 = model.geom_name2id(PADDLE_CONTACT_1_NAME)
        self.paddle_contact_id_2 = model.geom_name2id(PADDLE_CONTACT_2_NAME)
        self.racket_id = model.geom_name2id(RACKET_NAME)
        self._ids_set = True

    def _get_obs(self):
        """Return the observation: 7 joint positions, ball position, goal."""
        ball_pos = self.sim.data.body_xpos[self.ball_id]
        # np.concatenate copies, so the result is detached from the simulator buffers.
        return np.concatenate([
            self.sim.data.qpos[:7].copy(),  # 7 joint positions
            ball_pos,
            self.goal.copy(),
        ])

    def sample_context(self):
        """Draw a uniform random context of length ``ctxt_dim`` within the configured bounds."""
        low, high = self.context_range_bounds[0], self.context_range_bounds[1]
        return self.np_random.uniform(low, high, size=self.ctxt_dim)

    def reset_model(self):
        """Reset arm, ball, goal and reward function; return the first observation."""
        self.set_state(self.init_qpos_tt, self.init_qvel_tt)  # reset to initial sim state
        self.time_steps = 0
        self.ball_landing_pos = None
        self.hit_ball = False
        self.ball_contact_after_hit = False

        # Only the planar (x, y) part of the goal is kept or sampled; z is appended below.
        if self.fixed_goal:
            planar_goal = self.goal[:2]
        else:
            planar_goal = self.sample_context()[:2]

        # Fixed ball position and velocity, for both the 2- and 4-dimensional context setups.
        initial_ball_state = ball_init(random=False)
        for axis, joint_name in enumerate(self._BALL_JOINTS):
            self.sim.data.set_joint_qpos(joint_name, initial_ball_state[axis])
            self.sim.data.set_joint_qvel(joint_name, initial_ball_state[axis + 3])

        self.goal = np.concatenate((planar_goal, [self._GOAL_Z]))

        self.sim.model.body_pos[5] = self.goal[:3]  # Desired Landing Position, Yellow
        self.sim.model.body_pos[3] = np.array([0, 0, 0.5])  # Outgoing Ball Landing Position, Green
        self.sim.model.body_pos[4] = np.array([0, 0, 0.5])  # Incoming Ball Landing Position, Red
        self.sim.forward()

        self.reward_func.reset(self.goal)  # reset the reward function
        return self._get_obs()

    def _contact_checker(self, id_1, id_2):
        """Return True if the two geoms are currently in contact, in either order."""
        contacts = self.sim.data.contact
        wanted_pairs = ((id_1, id_2), (id_2, id_1))
        # Only the first ``ncon`` entries of the contact array are read.
        return any(
            (contacts[i].geom1, contacts[i].geom2) in wanted_pairs
            for i in range(self.sim.data.ncon)
        )

    def _update_contact_state(self):
        """Update the paddle-hit flag and, after a hit, the first landing contact."""
        if not self.hit_ball:
            # Either side of the paddle counts as a hit.
            self.hit_ball = (
                self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1)
                or self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2)
            )

        if self.hit_ball and not self.ball_contact_after_hit:
            # The floor is checked first, then the table.
            landed = (
                self._contact_checker(self.ball_contact_id, self.floor_contact_id)
                or self._contact_checker(self.ball_contact_id, self.table_contact_id)
            )
            if landed:
                self.ball_contact_after_hit = True
                # Copy: body_xpos is a view into the simulator buffer and would
                # otherwise keep changing as the simulation advances.
                self.ball_landing_pos = self.sim.data.body_xpos[self.ball_id].copy()

    def step(self, action):
        """Advance the simulation by one step.

        Args:
            action: Control input forwarded to ``do_simulation``.

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``info`` contains
            the key ``"hit_ball"``.
        """
        if not self._ids_set:
            self._set_ids()

        done = False
        # NOTE(review): the time limit is only passed to the reward function;
        # it does not set ``done`` -- confirm an outer wrapper enforces it.
        episode_end = self.time_steps + 1 >= MAX_EPISODE_STEPS

        self._update_contact_state()

        c_ball_pos = self.sim.data.body_xpos[self.ball_id]
        racket_pos = self.sim.data.geom_xpos[self.racket_id]  # TODO: use this to reach out the position of the paddle?
        if self.ball_landing_pos is not None:
            done = True
            episode_end = True

        reward = self.reward_func.get_reward(
            episode_end, c_ball_pos, racket_pos, self.hit_ball, self.ball_landing_pos
        )
        self.time_steps += 1

        # Gravity compensation on the joints is currently disabled:
        # action += self.sim.data.qfrc_bias[:7].copy()
        try:
            self.do_simulation(action, self.frame_skip)
        except mujoco_py.MujocoException:
            # An unstable simulation ends the episode with a fixed penalty.
            print('Simulation got unstable returning')
            done = True
            reward = -25

        ob = self._get_obs()
        return ob, reward, done, {"hit_ball": self.hit_ball}  # might add some information here ....

    def set_context(self, context):
        """Set the goal from an externally supplied context and return the new observation.

        Args:
            context: Sequence whose first ``ctxt_dim`` entries define the goal.

        NOTE(review): the reward function is not informed of the new goal here
        (reset_model() passes the goal to ``reward_func.reset``) -- confirm
        whether that is intended.
        """
        # Re-apply the current simulator state unchanged.
        current_state = self.sim.get_state()
        self.set_state(current_state.qpos.copy(), current_state.qvel.copy())

        goal = np.concatenate((context[:self.ctxt_dim], [self._GOAL_Z]))
        if self.ctxt_dim == 4:
            # NOTE(review): this yields [c0, c1, c2, c3, z, z], so ``goal[:3]``
            # below is (c0, c1, c2) rather than an (x, y, z) position, and the
            # observation grows to include all six values -- confirm the
            # intended layout.
            goal = np.concatenate((goal, [self._GOAL_Z]))
        self.goal = goal

        self.sim.model.body_pos[5] = self.goal[:3]  # TODO: Missing: Setting the desired incoming landing position
        self.sim.forward()
        return self._get_obs()