corrected reward for hopperjumprndminit + ALRReacher for iLQR
This commit is contained in:
parent
77927e9157
commit
c0502cf1d4
@ -148,6 +148,17 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRReacherSparseOptCtrl-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRReacherOptCtrlEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"steps_before_reward": 200,
|
||||
"n_links": 5,
|
||||
"balance": False,
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRReacherSparseBalanced-v0',
|
||||
entry_point='alr_envs.alr.mujoco:ALRReacherEnv',
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .reacher.alr_reacher import ALRReacherEnv
|
||||
from .reacher.alr_reacher import ALRReacherEnv, ALRReacherOptCtrlEnv
|
||||
from .reacher.balancing import BalancingEnv
|
||||
from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
||||
from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
|
||||
|
@ -94,10 +94,16 @@ class ALRHopperJumpEnv(HopperEnv):
|
||||
|
||||
class ALRHopperJumpRndmPosEnv(ALRHopperJumpEnv):
|
||||
def __init__(self, max_episode_steps=250):
|
||||
self.contact_with_floor = False
|
||||
self._floor_geom_id = None
|
||||
self._foot_geom_id = None
|
||||
super(ALRHopperJumpRndmPosEnv, self).__init__(exclude_current_positions_from_observation=False,
|
||||
reset_noise_scale=5e-1,
|
||||
max_episode_steps=max_episode_steps)
|
||||
|
||||
def reset_model(self):
|
||||
self._floor_geom_id = self.model.geom_name2id('floor')
|
||||
self._foot_geom_id = self.model.geom_name2id('foot_geom')
|
||||
noise_low = -self._reset_noise_scale
|
||||
noise_high = self._reset_noise_scale
|
||||
rnd_vec = self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq)
|
||||
@ -116,8 +122,12 @@ class ALRHopperJumpRndmPosEnv(ALRHopperJumpEnv):
|
||||
|
||||
self.current_step += 1
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
|
||||
self.contact_with_floor = self._contact_checker(self._floor_geom_id, self._foot_geom_id) if not \
|
||||
self.contact_with_floor else True
|
||||
|
||||
height_after = self.get_body_com("torso")[2]
|
||||
self.max_height = max(height_after, self.max_height)
|
||||
self.max_height = max(height_after, self.max_height) if self.contact_with_floor else 0
|
||||
|
||||
ctrl_cost = self.control_cost(action)
|
||||
costs = ctrl_cost
|
||||
@ -142,9 +152,19 @@ class ALRHopperJumpRndmPosEnv(ALRHopperJumpEnv):
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
def _contact_checker(self, id_1, id_2):
|
||||
for coni in range(0, self.sim.data.ncon):
|
||||
con = self.sim.data.contact[coni]
|
||||
collision = con.geom1 == id_1 and con.geom2 == id_2
|
||||
collision_trans = con.geom1 == id_2 and con.geom2 == id_1
|
||||
if collision or collision_trans:
|
||||
return True
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
render_mode = "human" # "human" or "partial" or "final"
|
||||
env = ALRHopperJumpEnv()
|
||||
# env = ALRHopperJumpEnv()
|
||||
env = ALRHopperJumpRndmPosEnv()
|
||||
obs = env.reset()
|
||||
|
||||
for i in range(2000):
|
||||
@ -152,7 +172,8 @@ if __name__ == '__main__':
|
||||
# test with random actions
|
||||
ac = env.action_space.sample()
|
||||
obs, rew, d, info = env.step(ac)
|
||||
if i % 10 == 0:
|
||||
# if i % 10 == 0:
|
||||
# env.render(mode=render_mode)
|
||||
env.render(mode=render_mode)
|
||||
if d:
|
||||
print('After ', i, ' steps, done: ', d)
|
||||
|
@ -87,6 +87,27 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
||||
[self._steps],
|
||||
])
|
||||
|
||||
class ALRReacherOptCtrlEnv(ALRReacherEnv):
|
||||
def __init__(self, steps_before_reward=200, n_links=5, balance=False):
|
||||
super(ALRReacherOptCtrlEnv, self).__init__(steps_before_reward, n_links, balance)
|
||||
self.goal = np.array([0.1, 0.1])
|
||||
|
||||
def _get_obs(self):
|
||||
theta = self.sim.data.qpos.flat[:self.n_links]
|
||||
return np.concatenate([
|
||||
theta,
|
||||
self.sim.data.qvel.flat[:self.n_links], # this is angular velocity
|
||||
])
|
||||
|
||||
def reset_model(self):
|
||||
qpos = self.init_qpos
|
||||
qpos[-2:] = self.goal
|
||||
qvel = self.init_qvel
|
||||
qvel[-2:] = 0
|
||||
self.set_state(qpos, qvel)
|
||||
self._steps = 0
|
||||
|
||||
return self._get_obs()
|
||||
|
||||
if __name__ == '__main__':
|
||||
nl = 5
|
||||
|
Loading…
Reference in New Issue
Block a user