context wip

This commit is contained in:
Maximilian Huettenrauch 2021-05-17 09:32:51 +02:00
parent b4ad3e6ddd
commit 7f512068c9
7 changed files with 32 additions and 18 deletions

View File

@ -26,7 +26,7 @@ class EpisodicSimpleReacherEnv(SimpleReacherEnv):
self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
@property
def start_pos(self):
def init_qpos(self):
return self._start_pos
# @property

View File

@ -62,6 +62,10 @@ class HoleReacher(gym.Env):
self.patches = [rect_1, rect_2, rect_3]
@property
def init_qpos(self):
return self.start_pos
@property
def end_effector(self):
return self._joints[self.n_links].T
@ -82,7 +86,7 @@ class HoleReacher(gym.Env):
"""
a single step with an action in joint velocity space
"""
vel = action # + 0.01 * np.random.randn(self.num_links)
vel = action # + 0.05 * np.random.randn(self.n_links)
acc = (vel - self._angle_velocity) / self.dt
self._angle_velocity = vel
self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
@ -237,7 +241,10 @@ class HoleReacher(gym.Env):
if self._steps == 1:
# fig, ax = plt.subplots()
# Add the patch to the Axes
[plt.gca().add_patch(rect) for rect in self.patches]
try:
[plt.gca().add_patch(rect) for rect in self.patches]
except RuntimeError:
pass
# plt.pause(0.01)
if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:

View File

@ -22,7 +22,7 @@ class SimpleReacherEnv(gym.Env):
super().__init__()
self.link_lengths = np.ones(n_links)
self.n_links = n_links
self.dt = 0.1
self.dt = 0.01
self.random_start = random_start
@ -56,10 +56,13 @@ class SimpleReacherEnv(gym.Env):
def step(self, action: np.ndarray):
# action = self._add_action_noise(action)
action = np.clip(action, -self.max_torque, self.max_torque)
# action = np.clip(action, -self.max_torque, self.max_torque)
vel = action
self._angle_velocity = self._angle_velocity + self.dt * action
self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
# self._angle_velocity = self._angle_velocity + self.dt * action
# self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
self._angle_velocity = vel
self._joint_angle = self._joint_angle + self.dt * self._angle_velocity
self._update_joints()
self._steps += 1
@ -111,7 +114,7 @@ class SimpleReacherEnv(gym.Env):
# reward_dist = np.exp(-0.1 * diff ** 2).mean()
# reward_dist = - (diff ** 2).mean()
reward_ctrl = (action ** 2).sum()
reward_ctrl = 1e-5 * (action ** 2).sum()
reward = reward_dist - reward_ctrl
return reward, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
@ -139,7 +142,7 @@ class SimpleReacherEnv(gym.Env):
# Sample uniformly in circle with radius R around center of reacher.
R = np.sum(self.link_lengths)
r = R * np.sqrt(self.np_random.uniform())
theta = self.np_random.uniform() * 2 * np.pi
theta = np.pi/2 + 0.001 * np.random.randn() # self.np_random.uniform() * 2 * np.pi
return center + r * np.stack([np.cos(theta), np.sin(theta)])
def seed(self, seed=None):
@ -170,8 +173,8 @@ class SimpleReacherEnv(gym.Env):
plt.xlim([-lim, lim])
plt.ylim([-lim, lim])
# plt.draw()
# plt.pause(1e-4) pushes window to foreground, which is annoying.
self.fig.canvas.flush_events()
plt.pause(1e-4) # pushes window to foreground, which is annoying.
# self.fig.canvas.flush_events()
def close(self):
del self.fig

View File

@ -56,6 +56,7 @@ class AlrMpEnvSampler:
vals = defaultdict(list)
for p in split_params:
self.env.reset()
obs, reward, done, info = self.env.step(p)
vals['obs'].append(obs)
vals['reward'].append(reward)
@ -82,8 +83,9 @@ class AlrContextualMpEnvSampler:
vals = defaultdict(list)
for i in range(repeat):
new_contexts = self.env.reset()
new_samples = dist.sample(new_contexts)
vals['new_contexts'].append(new_contexts)
new_samples, new_contexts = dist.sample(new_contexts)
vals['new_samples'].append(new_samples)
obs, reward, done, info = self.env.step(new_samples)
vals['obs'].append(obs)
@ -92,7 +94,8 @@ class AlrContextualMpEnvSampler:
vals['info'].append(info)
# do not return values above threshold
return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples],\
return np.vstack(vals['new_samples'])[:n_samples], np.vstack(vals['new_contexts'])[:n_samples], \
np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
_flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]

View File

@ -98,7 +98,7 @@ class DmpWrapper(MPWrapper):
def mp_rollout(self, action):
# if self.mp.start_pos is None:
self.mp.dmp_start_pos = self.env.init_qpos # start_pos
self.mp.dmp_start_pos = self.env.init_qpos.reshape((1, self.num_dof)) # start_pos
goal_pos, weight_matrix = self.goal_and_weights(action)
self.mp.set_weights(weight_matrix, goal_pos)
return self.mp.reference_trajectory(self.t)

View File

@ -22,7 +22,7 @@ class MPWrapper(gym.Wrapper, ABC):
):
super().__init__(env)
# self.num_dof = num_dof
self.num_dof = num_dof
# self.num_basis = num_basis
# self.duration = duration # seconds
@ -50,6 +50,7 @@ class MPWrapper(gym.Wrapper, ABC):
# for p, c in zip(params, contexts):
for p in params:
# self.configure(c)
# context = self.reset()
ob, reward, done, info = self.step(p)
obs.append(ob)
rewards.append(reward)

View File

@ -83,6 +83,6 @@ if __name__ == '__main__':
# example_mujoco()
# example_dmp()
# example_async()
# env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1)
env = gym.make("alr_envs:SimpleReacherDMP-v1")
env = gym.make("alr_envs:HoleReacherDMP-v0")
# env = gym.make("alr_envs:SimpleReacherDMP-v1")
print()