context wip

commit 7f512068c9
parent b4ad3e6ddd
@@ -26,7 +26,7 @@ class EpisodicSimpleReacherEnv(SimpleReacherEnv):
         self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
 
     @property
-    def start_pos(self):
+    def init_qpos(self):
         return self._start_pos
 
     # @property
@@ -62,6 +62,10 @@ class HoleReacher(gym.Env):
         self.patches = [rect_1, rect_2, rect_3]
 
+    @property
+    def init_qpos(self):
+        return self.start_pos
+
     @property
     def end_effector(self):
         return self._joints[self.n_links].T
 
@@ -82,7 +86,7 @@ class HoleReacher(gym.Env):
         """
         a single step with an action in joint velocity space
         """
-        vel = action # + 0.01 * np.random.randn(self.num_links)
+        vel = action # + 0.05 * np.random.randn(self.n_links)
         acc = (vel - self._angle_velocity) / self.dt
         self._angle_velocity = vel
         self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
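For reference, the velocity-space step above amounts to a finite-difference acceleration estimate followed by explicit Euler integration of the joint angles. A minimal standalone sketch of the same arithmetic (function and variable names are illustrative, not the environment's attributes):

import numpy as np

def velocity_step(joint_angles, prev_velocity, action, dt):
    vel = action                              # the action is interpreted as joint velocity
    acc = (vel - prev_velocity) / dt          # finite-difference acceleration, e.g. for smoothness penalties
    joint_angles = joint_angles + dt * vel    # explicit Euler update of the angles
    return joint_angles, vel, acc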
@@ -237,7 +241,10 @@ class HoleReacher(gym.Env):
         if self._steps == 1:
             # fig, ax = plt.subplots()
             # Add the patch to the Axes
-            [plt.gca().add_patch(rect) for rect in self.patches]
+            try:
+                [plt.gca().add_patch(rect) for rect in self.patches]
+            except RuntimeError:
+                pass
             # plt.pause(0.01)
 
         if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:
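The new try/except tolerates repeated render calls: matplotlib can raise a RuntimeError when a patch that is already attached to a figure is added again. A hedged sketch of the same guard in isolation (the patch construction here is illustrative):

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

patches = [Rectangle((0.0, 0.0), 1.0, 1.0), Rectangle((2.0, 0.0), 1.0, 1.0)]

def draw_obstacles():
    for rect in patches:
        try:
            plt.gca().add_patch(rect)   # may raise if the artist is already attached
        except RuntimeError:
            pass                        # already drawn in an earlier call; ignore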
@@ -22,7 +22,7 @@ class SimpleReacherEnv(gym.Env):
         super().__init__()
         self.link_lengths = np.ones(n_links)
         self.n_links = n_links
-        self.dt = 0.1
+        self.dt = 0.01
 
         self.random_start = random_start
 
@@ -56,10 +56,13 @@ class SimpleReacherEnv(gym.Env):
     def step(self, action: np.ndarray):
 
         # action = self._add_action_noise(action)
-        action = np.clip(action, -self.max_torque, self.max_torque)
+        # action = np.clip(action, -self.max_torque, self.max_torque)
+        vel = action
 
-        self._angle_velocity = self._angle_velocity + self.dt * action
-        self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
+        # self._angle_velocity = self._angle_velocity + self.dt * action
+        # self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
+        self._angle_velocity = vel
+        self._joint_angle = self._joint_angle + self.dt * self._angle_velocity
         self._update_joints()
         self._steps += 1
 
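This hunk switches SimpleReacherEnv from torque-style control (the action is integrated twice and angles are wrapped by angle_normalize) to direct velocity control (the action is integrated once and angles accumulate unwrapped). A self-contained sketch contrasting the two updates under the same dt semantics (helper names are hypothetical):

import numpy as np

def torque_step(angle, velocity, action, dt):
    # previous behaviour: the action acts like an acceleration/torque
    velocity = velocity + dt * action
    angle = (angle + dt * velocity + np.pi) % (2 * np.pi) - np.pi   # wrap to [-pi, pi)
    return angle, velocity

def velocity_step(angle, velocity, action, dt):
    # new behaviour: the action is the joint velocity itself
    velocity = action
    angle = angle + dt * velocity
    return angle, velocity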
@@ -111,7 +114,7 @@ class SimpleReacherEnv(gym.Env):
         # reward_dist = np.exp(-0.1 * diff ** 2).mean()
         # reward_dist = - (diff ** 2).mean()
 
-        reward_ctrl = (action ** 2).sum()
+        reward_ctrl = 1e-5 * (action ** 2).sum()
         reward = reward_dist - reward_ctrl
         return reward, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
 
@@ -139,7 +142,7 @@ class SimpleReacherEnv(gym.Env):
         # Sample uniformly in circle with radius R around center of reacher.
         R = np.sum(self.link_lengths)
         r = R * np.sqrt(self.np_random.uniform())
-        theta = self.np_random.uniform() * 2 * np.pi
+        theta = np.pi/2 + 0.001 * np.random.randn() # self.np_random.uniform() * 2 * np.pi
         return center + r * np.stack([np.cos(theta), np.sin(theta)])
 
     def seed(self, seed=None):
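Context for the replaced line: r = R * np.sqrt(u) with u ~ U(0, 1) is the standard way to sample uniformly by area inside a disk of radius R (plain r = R * u would over-sample the centre), while the new line pins theta near pi/2 with a tiny Gaussian jitter, i.e. it places the target almost directly above the reacher, which reads as a temporary WIP/debugging choice. A small numpy-only sketch of both variants (names are illustrative):

import numpy as np

rng = np.random.default_rng()

def sample_uniform_in_disk(center, R):
    r = R * np.sqrt(rng.uniform())                       # sqrt gives uniform density in area
    theta = rng.uniform() * 2 * np.pi
    return center + r * np.stack([np.cos(theta), np.sin(theta)])

def sample_almost_vertical(center, R):
    r = R * np.sqrt(rng.uniform())
    theta = np.pi / 2 + 0.001 * rng.standard_normal()    # nearly always "straight up"
    return center + r * np.stack([np.cos(theta), np.sin(theta)])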
@@ -170,8 +173,8 @@ class SimpleReacherEnv(gym.Env):
         plt.xlim([-lim, lim])
         plt.ylim([-lim, lim])
         # plt.draw()
-        # plt.pause(1e-4) pushes window to foreground, which is annoying.
-        self.fig.canvas.flush_events()
+        plt.pause(1e-4) # pushes window to foreground, which is annoying.
+        # self.fig.canvas.flush_events()
 
     def close(self):
         del self.fig
@@ -56,6 +56,7 @@ class AlrMpEnvSampler:
 
         vals = defaultdict(list)
         for p in split_params:
+            self.env.reset()
             obs, reward, done, info = self.env.step(p)
             vals['obs'].append(obs)
             vals['reward'].append(reward)
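The added self.env.reset() gives every parameter vector a fresh episode before the single step() call that executes the whole motion-primitive rollout. A simplified sketch of that sampling loop (only the loop body follows the diff; the surrounding function is hypothetical):

from collections import defaultdict

def sample_rollouts(env, split_params):
    vals = defaultdict(list)
    for p in split_params:
        env.reset()                              # fresh episode per parameter vector
        obs, reward, done, info = env.step(p)    # one step == one full MP rollout
        vals['obs'].append(obs)
        vals['reward'].append(reward)
        vals['done'].append(done)
        vals['info'].append(info)
    return vals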
@@ -82,8 +83,9 @@ class AlrContextualMpEnvSampler:
         vals = defaultdict(list)
         for i in range(repeat):
             new_contexts = self.env.reset()
 
-            new_samples = dist.sample(new_contexts)
+            vals['new_contexts'].append(new_contexts)
+            new_samples, new_contexts = dist.sample(new_contexts)
             vals['new_samples'].append(new_samples)
 
             obs, reward, done, info = self.env.step(new_samples)
             vals['obs'].append(obs)
@@ -92,7 +94,8 @@ class AlrContextualMpEnvSampler:
             vals['info'].append(info)
 
         # do not return values above threshold
-        return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples],\
+        return np.vstack(vals['new_samples'])[:n_samples], np.vstack(vals['new_contexts'])[:n_samples], \
+               np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
                _flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
 
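With this hunk the contextual sampler returns six arrays instead of four: the drawn samples and their contexts come first, followed by observations, rewards, dones and infos, each truncated to n_samples. A hedged sketch of how a caller would unpack the new signature (the call form itself is an assumption):

# assuming `sampler` is an AlrContextualMpEnvSampler instance and `dist` a conditional sampling distribution
new_samples, new_contexts, obs, rewards, dones, infos = sampler(dist, n_samples)
assert len(new_samples) == len(new_contexts) == len(obs)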
@@ -98,7 +98,7 @@ class DmpWrapper(MPWrapper):
 
     def mp_rollout(self, action):
         # if self.mp.start_pos is None:
-        self.mp.dmp_start_pos = self.env.init_qpos # start_pos
+        self.mp.dmp_start_pos = self.env.init_qpos.reshape((1, self.num_dof)) # start_pos
         goal_pos, weight_matrix = self.goal_and_weights(action)
         self.mp.set_weights(weight_matrix, goal_pos)
         return self.mp.reference_trajectory(self.t)
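The reshape makes the DMP start position explicitly two-dimensional, shape (1, num_dof) rather than (num_dof,), so it broadcasts cleanly against per-timestep arrays. A numpy-only illustration of the shape change (values are placeholders):

import numpy as np

num_dof = 5
init_qpos = np.zeros(num_dof)                           # flat start posture, shape (num_dof,)
dmp_start_pos = init_qpos.reshape((1, num_dof))         # shape (1, num_dof)

trajectory = np.zeros((100, num_dof)) + dmp_start_pos   # broadcasts over the time dimension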
@@ -22,7 +22,7 @@ class MPWrapper(gym.Wrapper, ABC):
     ):
         super().__init__(env)
 
-        # self.num_dof = num_dof
+        self.num_dof = num_dof
         # self.num_basis = num_basis
         # self.duration = duration # seconds
 
@@ -50,6 +50,7 @@ class MPWrapper(gym.Wrapper, ABC):
         # for p, c in zip(params, contexts):
         for p in params:
             # self.configure(c)
+            # context = self.reset()
             ob, reward, done, info = self.step(p)
             obs.append(ob)
             rewards.append(reward)
@@ -83,6 +83,6 @@ if __name__ == '__main__':
     # example_mujoco()
     # example_dmp()
     # example_async()
-    # env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1)
-    env = gym.make("alr_envs:SimpleReacherDMP-v1")
+    env = gym.make("alr_envs:HoleReacherDMP-v0")
+    # env = gym.make("alr_envs:SimpleReacherDMP-v1")
     print()
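A minimal usage sketch for the example environment selected above, following the standard gym API (that a single step executes the whole DMP rollout is an assumption based on the wrappers in this diff):

import gym

env = gym.make("alr_envs:HoleReacherDMP-v0")
obs = env.reset()
action = env.action_space.sample()          # DMP weights (and goal) as a flat parameter vector
obs, reward, done, info = env.step(action)
print(reward, done)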