slim down beerpong constructor further. Not sure, if we should merge the reward into the environment class.
This commit is contained in:
parent
d80df03145
commit
4dc33b0e97
@ -123,7 +123,7 @@
|
|||||||
<joint name="wam/palm_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7 2.7" />
|
<joint name="wam/palm_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7 2.7" />
|
||||||
<geom class="viz" pos="0 0 -0.06" mesh="wrist_palm_link_fine" />
|
<geom class="viz" pos="0 0 -0.06" mesh="wrist_palm_link_fine" />
|
||||||
<geom class="col" pos="0 0 -0.06" mesh="wrist_palm_link_convex" name="wrist_palm_link_convex_geom" />
|
<geom class="col" pos="0 0 -0.06" mesh="wrist_palm_link_convex" name="wrist_palm_link_convex_geom" />
|
||||||
<site name="init_ball_pos_site" size="0.025 0.025 0.025" pos="0.0 0.0 0.035" rgba="0 1 0 1"/>
|
<site name="init_ball_pos" size="0.025 0.025 0.025" pos="0.0 0.0 0.035" rgba="0 1 0 1"/>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
|
@ -123,7 +123,7 @@
|
|||||||
<joint name="wam/palm_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7 2.7" />
|
<joint name="wam/palm_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7 2.7" />
|
||||||
<geom class="viz" pos="0 0 -0.06" mesh="wrist_palm_link_fine" />
|
<geom class="viz" pos="0 0 -0.06" mesh="wrist_palm_link_fine" />
|
||||||
<geom class="col" pos="0 0 -0.06" mesh="wrist_palm_link_convex" name="wrist_palm_link_convex_geom" />
|
<geom class="col" pos="0 0 -0.06" mesh="wrist_palm_link_convex" name="wrist_palm_link_convex_geom" />
|
||||||
<site name="init_ball_pos_site" size="0.025 0.025 0.025" pos="0.0 0.0 0.035" rgba="0 1 0 1"/>
|
<site name="init_ball_pos" size="0.025 0.025 0.025" pos="0.0 0.0 0.035" rgba="0 1 0 1"/>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
|
@ -9,35 +9,31 @@ from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
|
|||||||
|
|
||||||
|
|
||||||
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
||||||
# TODO: always use gravity compensation
|
def __init__(self, frame_skip=2):
|
||||||
def __init__(self, frame_skip=2, apply_gravity_comp=True):
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
# Small Context -> Easier. Todo: Should we do different versions?
|
# Small Context -> Easier. Todo: Should we do different versions?
|
||||||
# self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
# self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
||||||
# "beerpong_wo_cup" + ".xml")
|
# "beerpong_wo_cup" + ".xml")
|
||||||
# self.cup_pos_min = np.array([-0.32, -2.2])
|
# self._cup_pos_min = np.array([-0.32, -2.2])
|
||||||
# self.cup_pos_max = np.array([0.32, -1.2])
|
# self._cup_pos_max = np.array([0.32, -1.2])
|
||||||
|
|
||||||
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
||||||
"beerpong_wo_cup_big_table" + ".xml")
|
"beerpong_wo_cup_big_table" + ".xml")
|
||||||
self._cup_pos_min = np.array([-1.42, -4.05])
|
self._cup_pos_min = np.array([-1.42, -4.05])
|
||||||
self._cup_pos_max = np.array([1.42, -1.25])
|
self._cup_pos_max = np.array([1.42, -1.25])
|
||||||
|
|
||||||
self.apply_gravity_comp = apply_gravity_comp
|
|
||||||
|
|
||||||
self._start_pos = np.array([0.0, 1.35, 0.0, 1.18, 0.0, -0.786, -1.59])
|
self._start_pos = np.array([0.0, 1.35, 0.0, 1.18, 0.0, -0.786, -1.59])
|
||||||
self._start_vel = np.zeros(7)
|
self._start_vel = np.zeros(7)
|
||||||
|
|
||||||
# TODO: check if we need to define that in the constructor?
|
|
||||||
self.ball_site_id = 0
|
|
||||||
self.ball_id = 11
|
|
||||||
self.cup_table_id = 10
|
|
||||||
|
|
||||||
self.release_step = 100 # time step of ball release
|
self.release_step = 100 # time step of ball release
|
||||||
self.ep_length = 600 // frame_skip
|
self.ep_length = 600 // frame_skip
|
||||||
|
|
||||||
self.reward_function = BeerPongReward()
|
self.reward_function = BeerPongReward()
|
||||||
self.repeat_action = frame_skip
|
self.repeat_action = frame_skip
|
||||||
|
self.model = None
|
||||||
|
self.site_id = lambda x: self.model.site_name2id(x)
|
||||||
|
self.body_id = lambda x: self.model.body_name2id(x)
|
||||||
|
|
||||||
MujocoEnv.__init__(self, self.xml_path, frame_skip=1)
|
MujocoEnv.__init__(self, self.xml_path, frame_skip=1)
|
||||||
utils.EzPickle.__init__(self)
|
utils.EzPickle.__init__(self)
|
||||||
|
|
||||||
@ -65,29 +61,26 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
# TODO: Ask Max why we need to set the state twice.
|
# TODO: Ask Max why we need to set the state twice.
|
||||||
self.set_state(start_pos, init_vel)
|
self.set_state(start_pos, init_vel)
|
||||||
start_pos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
start_pos[7::] = self.sim.data.site_xpos[self.site_id("init_ball_pos"), :].copy()
|
||||||
self.set_state(start_pos, init_vel)
|
self.set_state(start_pos, init_vel)
|
||||||
xy = self.np_random.uniform(self._cup_pos_min, self._cup_pos_max)
|
xy = self.np_random.uniform(self._cup_pos_min, self._cup_pos_max)
|
||||||
xyz = np.zeros(3)
|
xyz = np.zeros(3)
|
||||||
xyz[:2] = xy
|
xyz[:2] = xy
|
||||||
xyz[-1] = 0.840
|
xyz[-1] = 0.840
|
||||||
self.sim.model.body_pos[self.cup_table_id] = xyz
|
self.sim.model.body_pos[self.body_id("cup_table")] = xyz
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
crash = False
|
crash = False
|
||||||
for _ in range(self.repeat_action):
|
for _ in range(self.repeat_action):
|
||||||
if self.apply_gravity_comp:
|
applied_action = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
|
||||||
applied_action = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
|
|
||||||
else:
|
|
||||||
applied_action = a
|
|
||||||
try:
|
try:
|
||||||
self.do_simulation(applied_action, self.frame_skip)
|
self.do_simulation(applied_action, self.frame_skip)
|
||||||
self.reward_function.initialize(self)
|
self.reward_function.initialize(self)
|
||||||
# self.reward_function.check_contacts(self.sim) # I assume this is not important?
|
# self.reward_function.check_contacts(self.sim) # I assume this is not important?
|
||||||
if self._steps < self.release_step:
|
if self._steps < self.release_step:
|
||||||
self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.site_id("init_ball_pos"), :].copy()
|
||||||
self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.ball_site_id, :].copy()
|
self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.site_id("init_ball_pos"), :].copy()
|
||||||
crash = False
|
crash = False
|
||||||
except mujoco_py.builder.MujocoException:
|
except mujoco_py.builder.MujocoException:
|
||||||
crash = True
|
crash = True
|
||||||
@ -102,7 +95,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
else:
|
else:
|
||||||
reward = -30
|
reward = -30
|
||||||
done = True
|
done = True
|
||||||
reward_infos = {"success": False, "ball_pos": np.zeros(3), "ball_vel": np.zeros(3), "is_collided": False}
|
reward_infos = {"success": False, "ball_pos": np.zeros(3), "ball_vel": np.zeros(3), "is_collided": False}
|
||||||
|
|
||||||
infos = dict(
|
infos = dict(
|
||||||
reward=reward,
|
reward=reward,
|
||||||
@ -125,7 +118,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
theta_dot,
|
theta_dot,
|
||||||
cup_goal_diff_final,
|
cup_goal_diff_final,
|
||||||
cup_goal_diff_top,
|
cup_goal_diff_top,
|
||||||
self.sim.model.body_pos[self.cup_table_id][:2].copy(),
|
self.sim.model.body_pos[self.body_id("cup_table")][:2].copy(),
|
||||||
[self._steps],
|
[self._steps],
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -135,14 +128,14 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
|
|
||||||
class BeerPongEnvFixedReleaseStep(BeerPongEnv):
|
class BeerPongEnvFixedReleaseStep(BeerPongEnv):
|
||||||
def __init__(self, frame_skip=2, apply_gravity_comp=True):
|
def __init__(self, frame_skip=2):
|
||||||
super().__init__(frame_skip, apply_gravity_comp)
|
super().__init__(frame_skip)
|
||||||
self.release_step = 62 # empirically evaluated for frame_skip=2!
|
self.release_step = 62 # empirically evaluated for frame_skip=2!
|
||||||
|
|
||||||
|
|
||||||
class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
|
class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
|
||||||
def __init__(self, frame_skip=2, apply_gravity_comp=True):
|
def __init__(self, frame_skip=2):
|
||||||
super().__init__(frame_skip, apply_gravity_comp)
|
super().__init__(frame_skip)
|
||||||
self.release_step = 62 # empirically evaluated for frame_skip=2!
|
self.release_step = 62 # empirically evaluated for frame_skip=2!
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
|
@ -46,3 +46,4 @@ def nested_update(base: MutableMapping, update):
|
|||||||
"""
|
"""
|
||||||
for k, v in update.items():
|
for k, v in update.items():
|
||||||
base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v
|
base[k] = nested_update(base.get(k, {}), v) if isinstance(v, Mapping) else v
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user