small bp and tt updates
This commit is contained in:
parent
a0af743585
commit
92d05a9dfd
@ -204,7 +204,7 @@ register(id='TableTennis2DCtxt-v0',
|
||||
|
||||
register(id='TableTennis2DCtxt-v1',
|
||||
entry_point='alr_envs.alr.mujoco:TTEnvGym',
|
||||
max_episode_steps=1750,
|
||||
max_episode_steps=MAX_EPISODE_STEPS,
|
||||
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
|
||||
|
||||
register(id='TableTennis4DCtxt-v0',
|
||||
@ -365,11 +365,14 @@ for _v in _versions:
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
## Beerpong
|
||||
_versions = ["v0", "v1", "v2", "v3"]
|
||||
for _v in _versions:
|
||||
_env_id = f'BeerpongProMP-{_v}'
|
||||
register(
|
||||
id='BeerpongProMP-v0',
|
||||
id=_env_id,
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBeerPong-v0",
|
||||
"name": f"alr_envs:ALRBeerPong-{_v}",
|
||||
"wrappers": [mujoco.beerpong.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
@ -377,7 +380,7 @@ register(
|
||||
"duration": 1,
|
||||
"post_traj_time": 2,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"weights_scale": 1,
|
||||
"zero_start": True,
|
||||
"zero_goal": False,
|
||||
"policy_kwargs": {
|
||||
@ -387,7 +390,7 @@ register(
|
||||
}
|
||||
}
|
||||
)
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("BeerpongProMP-v0")
|
||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
## Table Tennis
|
||||
ctxt_dim = [2, 4]
|
||||
@ -429,7 +432,9 @@ register(
|
||||
"duration": 1.,
|
||||
"post_traj_time": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"weights_scale": 1,
|
||||
"off": -0.05,
|
||||
"bandwidth_factor": 3.5,
|
||||
"zero_start": True,
|
||||
"zero_goal": False,
|
||||
"policy_kwargs": {
|
||||
|
@ -127,13 +127,14 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
self._steps += 1
|
||||
else:
|
||||
reward = -30
|
||||
reward_infos = dict()
|
||||
success = False
|
||||
is_collided = False
|
||||
done = True
|
||||
ball_pos = np.zeros(3)
|
||||
ball_vel = np.zeros(3)
|
||||
|
||||
return ob, reward, done, dict(reward_dist=reward_dist,
|
||||
infos = dict(reward_dist=reward_dist,
|
||||
reward_ctrl=reward_ctrl,
|
||||
reward=reward,
|
||||
velocity=angular_vel,
|
||||
@ -145,6 +146,9 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
ball_vel=ball_vel,
|
||||
success=success,
|
||||
is_collided=is_collided, sim_crash=crash)
|
||||
infos.update(reward_infos)
|
||||
|
||||
return ob, reward, done, infos
|
||||
|
||||
def check_traj_in_joint_limits(self):
|
||||
return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
|
||||
|
@ -110,7 +110,7 @@ class BeerPongReward:
|
||||
success = ball_in_cup
|
||||
crash = self._is_collided
|
||||
else:
|
||||
reward = - 1e-4 * action_cost
|
||||
reward = - 1e-2 * action_cost
|
||||
success = False
|
||||
crash = False
|
||||
|
||||
|
@ -10,7 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
|
||||
|
||||
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
|
||||
|
||||
MAX_EPISODE_STEPS = 2875
|
||||
MAX_EPISODE_STEPS = 1750
|
||||
BALL_NAME_CONTACT = "target_ball_contact"
|
||||
BALL_NAME = "target_ball"
|
||||
TABLE_NAME = 'table_tennis_table'
|
||||
@ -42,9 +42,10 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
||||
else:
|
||||
raise ValueError("either 2 or 4 dimensional Contexts available")
|
||||
|
||||
action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
|
||||
action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
|
||||
self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')
|
||||
# has no effect as it is overwritten in init of super
|
||||
# action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
|
||||
# action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
|
||||
# self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')
|
||||
|
||||
self.time_steps = 0
|
||||
self.init_qpos_tt = np.array([0, 0, 0, 1.5, 0, 0, 1.5, 0, 0, 0])
|
||||
@ -159,7 +160,10 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
||||
done = True
|
||||
reward = -25
|
||||
ob = self._get_obs()
|
||||
return ob, reward, done, {"hit_ball": self.hit_ball} # might add some information here ....
|
||||
info = {"hit_ball": self.hit_ball,
|
||||
"q_pos": np.copy(self.sim.data.qpos[:7]),
|
||||
"ball_pos": np.copy(self.sim.data.qpos[7:])}
|
||||
return ob, reward, done, info # might add some information here ....
|
||||
|
||||
def set_context(self, context):
|
||||
old_state = self.sim.get_state()
|
||||
|
Loading…
Reference in New Issue
Block a user