holereach025 add success flag
This commit is contained in:
parent
f6cef69225
commit
448ebcde95
@ -94,19 +94,23 @@ class HoleReacher(gym.Env):
|
|||||||
|
|
||||||
# compute reward directly in step function
|
# compute reward directly in step function
|
||||||
|
|
||||||
|
success = False
|
||||||
reward = 0
|
reward = 0
|
||||||
if not self._is_collided:
|
if not self._is_collided:
|
||||||
if self._steps == 199:
|
if self._steps == 199:
|
||||||
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2
|
dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
|
||||||
|
reward = - dist ** 2
|
||||||
|
success = dist < 0.005
|
||||||
else:
|
else:
|
||||||
|
dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
|
||||||
# if self.collision_penalty != 0:
|
# if self.collision_penalty != 0:
|
||||||
# reward = -self.collision_penalty
|
# reward = -self.collision_penalty
|
||||||
# else:
|
# else:
|
||||||
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2 - self.collision_penalty
|
reward = - dist ** 2 - self.collision_penalty
|
||||||
|
|
||||||
reward -= 5e-8 * np.sum(acc ** 2)
|
reward -= 5e-8 * np.sum(acc ** 2)
|
||||||
|
|
||||||
info = {"is_collided": self._is_collided}
|
info = {"is_collided": self._is_collided, "is_success": success}
|
||||||
|
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])
|
# params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])
|
||||||
|
|
||||||
rew, info = test_env.rollout(params, render=True)
|
rew, info = test_env.rollout(params, render=False)
|
||||||
print(rew)
|
print(rew)
|
||||||
|
|
||||||
# out = env(params)
|
# out = env(params)
|
||||||
|
Loading…
Reference in New Issue
Block a user