bp frameskip version
This commit is contained in:
parent
863ef77e5e
commit
3cc1cd1456
@ -400,10 +400,11 @@ register(id='TableTennis4DCtxt-v0',
|
||||
register(
|
||||
id='ALRBeerPong-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||
max_episode_steps=600,
|
||||
max_episode_steps=150,
|
||||
kwargs={
|
||||
"rndm_goal": False,
|
||||
"cup_goal_pos": [0.1, -2.0]
|
||||
"cup_goal_pos": [0.1, -2.0],
|
||||
"frameskip": 4
|
||||
}
|
||||
)
|
||||
|
||||
@ -412,10 +413,11 @@ register(
|
||||
register(
|
||||
id='ALRBeerPong-v1',
|
||||
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||
max_episode_steps=600,
|
||||
max_episode_steps=150,
|
||||
kwargs={
|
||||
"rndm_goal": True,
|
||||
"cup_goal_pos": [-0.3, -1.2]
|
||||
"cup_goal_pos": [-0.3, -1.2],
|
||||
"frameskip": 4
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
# self._release_step = 130 # time step of ball release
|
||||
self.release_step = 100 # time step of ball release
|
||||
|
||||
self.ep_length = 600 # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt)
|
||||
self.ep_length = 600//frame_skip
|
||||
self.cup_table_id = 10
|
||||
|
||||
if noisy:
|
||||
@ -59,8 +59,8 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
self.noise_std = 0
|
||||
reward_function = BeerPongReward
|
||||
self.reward_function = reward_function()
|
||||
|
||||
MujocoEnv.__init__(self, self.xml_path, frame_skip)
|
||||
self.repeat_action = frame_skip
|
||||
MujocoEnv.__init__(self, self.xml_path, frame_skip=1)
|
||||
utils.EzPickle.__init__(self)
|
||||
|
||||
@property
|
||||
@ -106,26 +106,26 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def step(self, a):
|
||||
# if a.shape[0] == 8: # we learn also when to release the ball
|
||||
# self._release_step = a[-1]
|
||||
# self._release_step = np.clip(self._release_step, 50, 250)
|
||||
# self.release_step = 0.5/self.dt
|
||||
reward_dist = 0.0
|
||||
angular_vel = 0.0
|
||||
applied_action = a
|
||||
reward_ctrl = - np.square(applied_action).sum()
|
||||
if self.apply_gravity_comp:
|
||||
applied_action += self.sim.data.qfrc_bias[:len(applied_action)].copy() / self.model.actuator_gear[:, 0]
|
||||
try:
|
||||
self.do_simulation(applied_action, self.frame_skip)
|
||||
if self._steps < self.release_step:
|
||||
self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
||||
self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.ball_site_id, :].copy()
|
||||
elif self._steps == self.release_step and self.add_noise:
|
||||
self.sim.data.qvel[7::] += self.noise_std * np.random.randn(3)
|
||||
crash = False
|
||||
except mujoco_py.builder.MujocoException:
|
||||
crash = True
|
||||
|
||||
for _ in range(self.repeat_action):
|
||||
if self.apply_gravity_comp:
|
||||
applied_action = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
|
||||
else:
|
||||
applied_action = a
|
||||
try:
|
||||
self.do_simulation(applied_action, self.frame_skip)
|
||||
self.reward_function.initialize(self)
|
||||
self.reward_function.check_contacts(self.sim)
|
||||
if self._steps < self.release_step:
|
||||
self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
|
||||
self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.ball_site_id, :].copy()
|
||||
elif self._steps == self.release_step and self.add_noise:
|
||||
self.sim.data.qvel[7::] += self.noise_std * np.random.randn(3)
|
||||
crash = False
|
||||
except mujoco_py.builder.MujocoException:
|
||||
crash = True
|
||||
# joint_cons_viol = self.check_traj_in_joint_limits()
|
||||
|
||||
ob = self._get_obs()
|
||||
@ -148,7 +148,6 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
ball_vel = np.zeros(3)
|
||||
|
||||
infos = dict(reward_dist=reward_dist,
|
||||
reward_ctrl=reward_ctrl,
|
||||
reward=reward,
|
||||
velocity=angular_vel,
|
||||
# traj=self._q_pos,
|
||||
@ -176,16 +175,14 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||
[self._steps],
|
||||
])
|
||||
|
||||
# TODO
|
||||
@property
|
||||
def active_obs(self):
|
||||
return np.hstack([
|
||||
[False] * 7, # cos
|
||||
[False] * 7, # sin
|
||||
[True] * 2, # xy position of cup
|
||||
[False] # env steps
|
||||
])
|
||||
def dt(self):
|
||||
return super(ALRBeerBongEnv, self).dt()*self.repeat_action
|
||||
|
||||
class ALRBeerPongStepEnv(ALRBeerBongEnv):
|
||||
def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False,
|
||||
rndm_goal=False, cup_goal_pos=None):
|
||||
super(ALRBeerPongStepEnv, self).__init__(frame_skip, apply_gravity_comp, noisy, rndm_goal, cup_goal_pos)
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = ALRBeerBongEnv(rndm_goal=True)
|
||||
|
@ -60,7 +60,7 @@ class BeerPongReward:
|
||||
self.noisy_bp = noisy
|
||||
self._t_min_final_dist = -1
|
||||
|
||||
def compute_reward(self, env, action):
|
||||
def initialize(self, env):
|
||||
|
||||
if not self.is_initialized:
|
||||
self.is_initialized = True
|
||||
@ -77,12 +77,14 @@ class BeerPongReward:
|
||||
self.ground_collision_id = env.sim.model._geom_name2id["ground"]
|
||||
self.robot_collision_ids = [env.sim.model._geom_name2id[name] for name in self.robot_collision_objects]
|
||||
|
||||
def compute_reward(self, env, action):
|
||||
|
||||
goal_pos = env.sim.data.site_xpos[self.goal_id]
|
||||
ball_pos = env.sim.data.body_xpos[self.ball_id]
|
||||
ball_vel = env.sim.data.body_xvelp[self.ball_id]
|
||||
goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
|
||||
|
||||
self._check_contacts(env.sim)
|
||||
self.check_contacts(env.sim)
|
||||
self.dists.append(np.linalg.norm(goal_pos - ball_pos))
|
||||
self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
|
||||
self.dist_ground_cup = np.linalg.norm(ball_pos-goal_pos) \
|
||||
@ -137,7 +139,7 @@ class BeerPongReward:
|
||||
|
||||
return reward, infos
|
||||
|
||||
def _check_contacts(self, sim):
|
||||
def check_contacts(self, sim):
|
||||
if not self.ball_table_contact:
|
||||
self.ball_table_contact = self._check_collision_single_objects(sim, self.ball_collision_id,
|
||||
self.table_collision_id)
|
||||
|
Loading…
Reference in New Issue
Block a user