context wip
This commit is contained in:
parent
b4ad3e6ddd
commit
7f512068c9
@ -26,7 +26,7 @@ class EpisodicSimpleReacherEnv(SimpleReacherEnv):
|
|||||||
self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def init_qpos(self):
|
||||||
return self._start_pos
|
return self._start_pos
|
||||||
|
|
||||||
# @property
|
# @property
|
||||||
|
@ -62,6 +62,10 @@ class HoleReacher(gym.Env):
|
|||||||
|
|
||||||
self.patches = [rect_1, rect_2, rect_3]
|
self.patches = [rect_1, rect_2, rect_3]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def init_qpos(self):
|
||||||
|
return self.start_pos
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def end_effector(self):
|
def end_effector(self):
|
||||||
return self._joints[self.n_links].T
|
return self._joints[self.n_links].T
|
||||||
@ -82,7 +86,7 @@ class HoleReacher(gym.Env):
|
|||||||
"""
|
"""
|
||||||
a single step with an action in joint velocity space
|
a single step with an action in joint velocity space
|
||||||
"""
|
"""
|
||||||
vel = action # + 0.01 * np.random.randn(self.num_links)
|
vel = action # + 0.05 * np.random.randn(self.n_links)
|
||||||
acc = (vel - self._angle_velocity) / self.dt
|
acc = (vel - self._angle_velocity) / self.dt
|
||||||
self._angle_velocity = vel
|
self._angle_velocity = vel
|
||||||
self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
|
self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
|
||||||
@ -237,7 +241,10 @@ class HoleReacher(gym.Env):
|
|||||||
if self._steps == 1:
|
if self._steps == 1:
|
||||||
# fig, ax = plt.subplots()
|
# fig, ax = plt.subplots()
|
||||||
# Add the patch to the Axes
|
# Add the patch to the Axes
|
||||||
[plt.gca().add_patch(rect) for rect in self.patches]
|
try:
|
||||||
|
[plt.gca().add_patch(rect) for rect in self.patches]
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
# plt.pause(0.01)
|
# plt.pause(0.01)
|
||||||
|
|
||||||
if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:
|
if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:
|
||||||
|
@ -22,7 +22,7 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.link_lengths = np.ones(n_links)
|
self.link_lengths = np.ones(n_links)
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
self.dt = 0.1
|
self.dt = 0.01
|
||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
@ -56,10 +56,13 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
|
|
||||||
# action = self._add_action_noise(action)
|
# action = self._add_action_noise(action)
|
||||||
action = np.clip(action, -self.max_torque, self.max_torque)
|
# action = np.clip(action, -self.max_torque, self.max_torque)
|
||||||
|
vel = action
|
||||||
|
|
||||||
self._angle_velocity = self._angle_velocity + self.dt * action
|
# self._angle_velocity = self._angle_velocity + self.dt * action
|
||||||
self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
|
# self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
|
||||||
|
self._angle_velocity = vel
|
||||||
|
self._joint_angle = self._joint_angle + self.dt * self._angle_velocity
|
||||||
self._update_joints()
|
self._update_joints()
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
@ -111,7 +114,7 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
# reward_dist = np.exp(-0.1 * diff ** 2).mean()
|
# reward_dist = np.exp(-0.1 * diff ** 2).mean()
|
||||||
# reward_dist = - (diff ** 2).mean()
|
# reward_dist = - (diff ** 2).mean()
|
||||||
|
|
||||||
reward_ctrl = (action ** 2).sum()
|
reward_ctrl = 1e-5 * (action ** 2).sum()
|
||||||
reward = reward_dist - reward_ctrl
|
reward = reward_dist - reward_ctrl
|
||||||
return reward, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
return reward, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||||
|
|
||||||
@ -139,7 +142,7 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
# Sample uniformly in circle with radius R around center of reacher.
|
# Sample uniformly in circle with radius R around center of reacher.
|
||||||
R = np.sum(self.link_lengths)
|
R = np.sum(self.link_lengths)
|
||||||
r = R * np.sqrt(self.np_random.uniform())
|
r = R * np.sqrt(self.np_random.uniform())
|
||||||
theta = self.np_random.uniform() * 2 * np.pi
|
theta = np.pi/2 + 0.001 * np.random.randn() # self.np_random.uniform() * 2 * np.pi
|
||||||
return center + r * np.stack([np.cos(theta), np.sin(theta)])
|
return center + r * np.stack([np.cos(theta), np.sin(theta)])
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
@ -170,8 +173,8 @@ class SimpleReacherEnv(gym.Env):
|
|||||||
plt.xlim([-lim, lim])
|
plt.xlim([-lim, lim])
|
||||||
plt.ylim([-lim, lim])
|
plt.ylim([-lim, lim])
|
||||||
# plt.draw()
|
# plt.draw()
|
||||||
# plt.pause(1e-4) pushes window to foreground, which is annoying.
|
plt.pause(1e-4) # pushes window to foreground, which is annoying.
|
||||||
self.fig.canvas.flush_events()
|
# self.fig.canvas.flush_events()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
del self.fig
|
del self.fig
|
||||||
|
@ -56,6 +56,7 @@ class AlrMpEnvSampler:
|
|||||||
|
|
||||||
vals = defaultdict(list)
|
vals = defaultdict(list)
|
||||||
for p in split_params:
|
for p in split_params:
|
||||||
|
self.env.reset()
|
||||||
obs, reward, done, info = self.env.step(p)
|
obs, reward, done, info = self.env.step(p)
|
||||||
vals['obs'].append(obs)
|
vals['obs'].append(obs)
|
||||||
vals['reward'].append(reward)
|
vals['reward'].append(reward)
|
||||||
@ -82,8 +83,9 @@ class AlrContextualMpEnvSampler:
|
|||||||
vals = defaultdict(list)
|
vals = defaultdict(list)
|
||||||
for i in range(repeat):
|
for i in range(repeat):
|
||||||
new_contexts = self.env.reset()
|
new_contexts = self.env.reset()
|
||||||
|
vals['new_contexts'].append(new_contexts)
|
||||||
new_samples = dist.sample(new_contexts)
|
new_samples, new_contexts = dist.sample(new_contexts)
|
||||||
|
vals['new_samples'].append(new_samples)
|
||||||
|
|
||||||
obs, reward, done, info = self.env.step(new_samples)
|
obs, reward, done, info = self.env.step(new_samples)
|
||||||
vals['obs'].append(obs)
|
vals['obs'].append(obs)
|
||||||
@ -92,7 +94,8 @@ class AlrContextualMpEnvSampler:
|
|||||||
vals['info'].append(info)
|
vals['info'].append(info)
|
||||||
|
|
||||||
# do not return values above threshold
|
# do not return values above threshold
|
||||||
return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples],\
|
return np.vstack(vals['new_samples'])[:n_samples], np.vstack(vals['new_contexts'])[:n_samples], \
|
||||||
|
np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
|
||||||
_flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
|
_flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,7 +98,7 @@ class DmpWrapper(MPWrapper):
|
|||||||
|
|
||||||
def mp_rollout(self, action):
|
def mp_rollout(self, action):
|
||||||
# if self.mp.start_pos is None:
|
# if self.mp.start_pos is None:
|
||||||
self.mp.dmp_start_pos = self.env.init_qpos # start_pos
|
self.mp.dmp_start_pos = self.env.init_qpos.reshape((1, self.num_dof)) # start_pos
|
||||||
goal_pos, weight_matrix = self.goal_and_weights(action)
|
goal_pos, weight_matrix = self.goal_and_weights(action)
|
||||||
self.mp.set_weights(weight_matrix, goal_pos)
|
self.mp.set_weights(weight_matrix, goal_pos)
|
||||||
return self.mp.reference_trajectory(self.t)
|
return self.mp.reference_trajectory(self.t)
|
||||||
|
@ -22,7 +22,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
):
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
# self.num_dof = num_dof
|
self.num_dof = num_dof
|
||||||
# self.num_basis = num_basis
|
# self.num_basis = num_basis
|
||||||
# self.duration = duration # seconds
|
# self.duration = duration # seconds
|
||||||
|
|
||||||
@ -50,6 +50,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
# for p, c in zip(params, contexts):
|
# for p, c in zip(params, contexts):
|
||||||
for p in params:
|
for p in params:
|
||||||
# self.configure(c)
|
# self.configure(c)
|
||||||
|
# context = self.reset()
|
||||||
ob, reward, done, info = self.step(p)
|
ob, reward, done, info = self.step(p)
|
||||||
obs.append(ob)
|
obs.append(ob)
|
||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
|
@ -83,6 +83,6 @@ if __name__ == '__main__':
|
|||||||
# example_mujoco()
|
# example_mujoco()
|
||||||
# example_dmp()
|
# example_dmp()
|
||||||
# example_async()
|
# example_async()
|
||||||
# env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1)
|
env = gym.make("alr_envs:HoleReacherDMP-v0")
|
||||||
env = gym.make("alr_envs:SimpleReacherDMP-v1")
|
# env = gym.make("alr_envs:SimpleReacherDMP-v1")
|
||||||
print()
|
print()
|
Loading…
Reference in New Issue
Block a user