bp frameskip version

This commit is contained in:
Onur 2022-05-29 12:15:04 +02:00
parent 863ef77e5e
commit 3cc1cd1456
3 changed files with 38 additions and 37 deletions

View File

@ -400,10 +400,11 @@ register(id='TableTennis4DCtxt-v0',
register( register(
id='ALRBeerPong-v0', id='ALRBeerPong-v0',
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv', entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
max_episode_steps=600, max_episode_steps=150,
kwargs={ kwargs={
"rndm_goal": False, "rndm_goal": False,
"cup_goal_pos": [0.1, -2.0] "cup_goal_pos": [0.1, -2.0],
"frameskip": 4
} }
) )
@ -412,10 +413,11 @@ register(
register( register(
id='ALRBeerPong-v1', id='ALRBeerPong-v1',
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv', entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
max_episode_steps=600, max_episode_steps=150,
kwargs={ kwargs={
"rndm_goal": True, "rndm_goal": True,
"cup_goal_pos": [-0.3, -1.2] "cup_goal_pos": [-0.3, -1.2],
"frameskip": 4
} }
) )

View File

@ -50,7 +50,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
# self._release_step = 130 # time step of ball release # self._release_step = 130 # time step of ball release
self.release_step = 100 # time step of ball release self.release_step = 100 # time step of ball release
self.ep_length = 600 # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt) self.ep_length = 600//frame_skip
self.cup_table_id = 10 self.cup_table_id = 10
if noisy: if noisy:
@ -59,8 +59,8 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
self.noise_std = 0 self.noise_std = 0
reward_function = BeerPongReward reward_function = BeerPongReward
self.reward_function = reward_function() self.reward_function = reward_function()
self.repeat_action = frame_skip
MujocoEnv.__init__(self, self.xml_path, frame_skip) MujocoEnv.__init__(self, self.xml_path, frame_skip=1)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@property @property
@ -106,18 +106,18 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def step(self, a): def step(self, a):
# if a.shape[0] == 8: # we learn also when to release the ball
# self._release_step = a[-1]
# self._release_step = np.clip(self._release_step, 50, 250)
# self.release_step = 0.5/self.dt
reward_dist = 0.0 reward_dist = 0.0
angular_vel = 0.0 angular_vel = 0.0
applied_action = a
reward_ctrl = - np.square(applied_action).sum() for _ in range(self.repeat_action):
if self.apply_gravity_comp: if self.apply_gravity_comp:
applied_action += self.sim.data.qfrc_bias[:len(applied_action)].copy() / self.model.actuator_gear[:, 0] applied_action = a + self.sim.data.qfrc_bias[:len(a)].copy() / self.model.actuator_gear[:, 0]
else:
applied_action = a
try: try:
self.do_simulation(applied_action, self.frame_skip) self.do_simulation(applied_action, self.frame_skip)
self.reward_function.initialize(self)
self.reward_function.check_contacts(self.sim)
if self._steps < self.release_step: if self._steps < self.release_step:
self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy() self.sim.data.qpos[7::] = self.sim.data.site_xpos[self.ball_site_id, :].copy()
self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.ball_site_id, :].copy() self.sim.data.qvel[7::] = self.sim.data.site_xvelp[self.ball_site_id, :].copy()
@ -148,7 +148,6 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
ball_vel = np.zeros(3) ball_vel = np.zeros(3)
infos = dict(reward_dist=reward_dist, infos = dict(reward_dist=reward_dist,
reward_ctrl=reward_ctrl,
reward=reward, reward=reward,
velocity=angular_vel, velocity=angular_vel,
# traj=self._q_pos, # traj=self._q_pos,
@ -176,16 +175,14 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
[self._steps], [self._steps],
]) ])
# TODO
@property @property
def active_obs(self): def dt(self):
return np.hstack([ return super(ALRBeerBongEnv, self).dt()*self.repeat_action
[False] * 7, # cos
[False] * 7, # sin
[True] * 2, # xy position of cup
[False] # env steps
])
class ALRBeerPongStepEnv(ALRBeerBongEnv):
def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False,
rndm_goal=False, cup_goal_pos=None):
super(ALRBeerPongStepEnv, self).__init__(frame_skip, apply_gravity_comp, noisy, rndm_goal, cup_goal_pos)
if __name__ == "__main__": if __name__ == "__main__":
env = ALRBeerBongEnv(rndm_goal=True) env = ALRBeerBongEnv(rndm_goal=True)

View File

@ -60,7 +60,7 @@ class BeerPongReward:
self.noisy_bp = noisy self.noisy_bp = noisy
self._t_min_final_dist = -1 self._t_min_final_dist = -1
def compute_reward(self, env, action): def initialize(self, env):
if not self.is_initialized: if not self.is_initialized:
self.is_initialized = True self.is_initialized = True
@ -77,12 +77,14 @@ class BeerPongReward:
self.ground_collision_id = env.sim.model._geom_name2id["ground"] self.ground_collision_id = env.sim.model._geom_name2id["ground"]
self.robot_collision_ids = [env.sim.model._geom_name2id[name] for name in self.robot_collision_objects] self.robot_collision_ids = [env.sim.model._geom_name2id[name] for name in self.robot_collision_objects]
def compute_reward(self, env, action):
goal_pos = env.sim.data.site_xpos[self.goal_id] goal_pos = env.sim.data.site_xpos[self.goal_id]
ball_pos = env.sim.data.body_xpos[self.ball_id] ball_pos = env.sim.data.body_xpos[self.ball_id]
ball_vel = env.sim.data.body_xvelp[self.ball_id] ball_vel = env.sim.data.body_xvelp[self.ball_id]
goal_final_pos = env.sim.data.site_xpos[self.goal_final_id] goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
self._check_contacts(env.sim) self.check_contacts(env.sim)
self.dists.append(np.linalg.norm(goal_pos - ball_pos)) self.dists.append(np.linalg.norm(goal_pos - ball_pos))
self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos)) self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
self.dist_ground_cup = np.linalg.norm(ball_pos-goal_pos) \ self.dist_ground_cup = np.linalg.norm(ball_pos-goal_pos) \
@ -137,7 +139,7 @@ class BeerPongReward:
return reward, infos return reward, infos
def _check_contacts(self, sim): def check_contacts(self, sim):
if not self.ball_table_contact: if not self.ball_table_contact:
self.ball_table_contact = self._check_collision_single_objects(sim, self.ball_collision_id, self.ball_table_contact = self._check_collision_single_objects(sim, self.ball_collision_id,
self.table_collision_id) self.table_collision_id)