fixed some issues with SimpleReacher state space
This commit is contained in:
parent
5b83539109
commit
aec332ff0c
@ -8,7 +8,7 @@ from gym.utils import seeding
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if not os.environ.get("DISPLAY", None):
|
||||
if os.environ.get("DISPLAY", None):
|
||||
mpl.use('Qt5Agg')
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
|
||||
self._goal_pos = None
|
||||
|
||||
self.joints = None
|
||||
self._joints = None
|
||||
self._joint_angle = None
|
||||
self._angle_velocity = None
|
||||
|
||||
@ -37,7 +37,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
state_bound = np.hstack([
|
||||
[np.pi] * self.n_links,
|
||||
[np.inf] * self.n_links,
|
||||
[np.inf],
|
||||
[np.inf] * 2, # x-y coordinates of target distance
|
||||
[np.inf] # TODO: Maybe
|
||||
])
|
||||
self.action_space = spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)
|
||||
@ -87,13 +87,13 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
|
||||
def _update_joints(self):
|
||||
"""
|
||||
update joints to get new end effector position. The other links are only required for rendering.
|
||||
update _joints to get new end effector position. The other links are only required for rendering.
|
||||
Returns:
|
||||
|
||||
"""
|
||||
angles = np.cumsum(self._joint_angle)
|
||||
x = self.link_lengths * np.vstack([np.cos(angles), np.sin(angles)])
|
||||
self.joints[1:] = self.joints[0] + np.cumsum(x.T, axis=0)
|
||||
self._joints[1:] = self._joints[0] + np.cumsum(x.T, axis=0)
|
||||
|
||||
def _get_reward(self, action):
|
||||
diff = self.end_effector - self._goal_pos
|
||||
@ -113,14 +113,14 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
# Sample only orientation of first link, i.e. the arm is always straight.
|
||||
self._joint_angle = np.hstack([[self.np_random.uniform(-np.pi, np.pi)], np.zeros(self.n_links - 1)])
|
||||
self._angle_velocity = np.zeros(self.n_links)
|
||||
self.joints = np.zeros((self.n_links + 1, 2))
|
||||
self._joints = np.zeros((self.n_links + 1, 2))
|
||||
self._update_joints()
|
||||
|
||||
self._goal_pos = self._get_random_goal()
|
||||
return self._get_obs().copy()
|
||||
|
||||
def _get_random_goal(self):
|
||||
center = self.joints[0]
|
||||
center = self._joints[0]
|
||||
|
||||
# Sample uniformly in circle with radius R around center of reacher.
|
||||
R = np.sum(self.link_lengths)
|
||||
@ -143,7 +143,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
plt.cla()
|
||||
|
||||
# Arm
|
||||
plt.plot(self.joints[:, 0], self.joints[:, 1], 'ro-', markerfacecolor='k')
|
||||
plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')
|
||||
|
||||
# goal
|
||||
goal_pos = self._goal_pos.T
|
||||
@ -162,7 +162,7 @@ class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||
|
||||
@property
|
||||
def end_effector(self):
|
||||
return self.joints[self.n_links].T
|
||||
return self._joints[self.n_links].T
|
||||
|
||||
|
||||
def angle_normalize(x):
|
||||
|
Loading…
Reference in New Issue
Block a user