dmp env wrappers initial
parent 72b5e2bfc9
commit a8fcbd6fb0
@@ -63,15 +63,15 @@ register(
     }
 )
 
-register(
-    id='ALRReacherSparse-v0',
-    entry_point='alr_envs.mujoco:ALRReacherEnv',
-    max_episode_steps=200,
-    kwargs={
-        "steps_before_reward": 200,
-        "n_links": 7,
-    }
-)
+# register(
+#     id='ALRReacherSparse-v0',
+#     entry_point='alr_envs.mujoco:ALRReacherEnv',
+#     max_episode_steps=200,
+#     kwargs={
+#         "steps_before_reward": 200,
+#         "n_links": 7,
+#     }
+# )
 
 register(
     id='ALRReacher7Short-v0',

alr_envs/classic_control/hole_reacher.py (new file, +291 lines)
@@ -0,0 +1,291 @@
import gym
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches


def ccw(A, B, C):
    return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12


# Return true if line segments AB and CD intersect
def intersect(A, B, C, D):
    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
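

# Quick sanity check of the helpers above (illustrative only, not part of the
# original commit): crossing segments intersect, parallel ones do not.
# >>> intersect((0, 0), (1, 1), (0, 1), (1, 0))
# True
# >>> intersect((0, 0), (1, 0), (0, 1), (1, 1))
# False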


class HoleReacher(gym.Env):

    def __init__(self, num_links, hole_x, hole_width, hole_depth, allow_self_collision=False,
                 allow_wall_collision=False, collision_penalty=1000):
        self.hole_x = hole_x  # x-position of center of hole
        self.hole_width = hole_width  # width of hole
        self.hole_depth = hole_depth  # depth of hole
        self.num_links = num_links
        self.link_lengths = np.ones((num_links, 1))
        self.bottom_center_of_hole = np.hstack([hole_x, -hole_depth])
        self.top_center_of_hole = np.hstack([hole_x, 0])
        self.left_wall_edge = np.hstack([hole_x - self.hole_width / 2, 0])
        self.right_wall_edge = np.hstack([hole_x + self.hole_width / 2, 0])
        self.allow_self_collision = allow_self_collision
        self.allow_wall_collision = allow_wall_collision
        self.collision_penalty = collision_penalty

        self._joints = None
        self._joint_angles = None
        self._angle_velocity = None
        self.start_pos = np.hstack([[np.pi / 2], np.zeros(self.num_links - 1)])
        self.start_vel = np.zeros(self.num_links)

        self._dt = 0.01

        action_bound = np.pi * np.ones((self.num_links,))
        state_bound = np.hstack([
            [np.pi] * self.num_links,  # cos
            [np.pi] * self.num_links,  # sin
            [np.inf] * self.num_links,  # velocity
            [np.inf] * 2,  # x-y coordinates of target distance
            [np.inf]  # env steps, because reward start after n steps TODO: Maybe
        ])
        self.action_space = gym.spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)
        self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)

        self.fig = None
        rect_1 = patches.Rectangle((-self.num_links, -1),
                                   self.num_links + self.hole_x - self.hole_width / 2, 1,
                                   fill=True, edgecolor='k', facecolor='k')
        rect_2 = patches.Rectangle((self.hole_x + self.hole_width / 2, -1),
                                   self.num_links - self.hole_x + self.hole_width / 2, 1,
                                   fill=True, edgecolor='k', facecolor='k')
        rect_3 = patches.Rectangle((self.hole_x - self.hole_width / 2, -1), self.hole_width,
                                   1 - self.hole_depth,
                                   fill=True, edgecolor='k', facecolor='k')

        self.patches = [rect_1, rect_2, rect_3]

    @property
    def end_effector(self):
        return self._joints[self.num_links].T

    def step(self, action):
        # vel = (action - self._joint_angles) / self._dt
        # acc = (vel - self._angle_velocity) / self._dt
        # self._angle_velocity = vel
        # self._joint_angles = action

        vel = action
        acc = (vel - self._angle_velocity) / self._dt
        self._angle_velocity = vel
        self._joint_angles = self._joint_angles + self._dt * self._angle_velocity

        self._update_joints()

        rew = self._reward()

        rew -= 1e-6 * np.sum(acc**2)

        if self._steps == 180:
            rew -= (0.1 * np.sum(vel**2) ** 2
                    + 1e-3 * np.sum(action**2)
                    )

        if self._is_collided:
            rew -= self.collision_penalty

        info = {}

        self._steps += 1

        return self._get_obs().copy(), rew, self._is_collided, info
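
    # NOTE (added for clarity, not part of the original commit): step() treats
    # the action as a joint-velocity command, integrates the pose with explicit
    # Euler at dt=0.01, and ends the episode on any collision. The distance
    # reward is given at step 180 (or on collision), and an extra velocity and
    # action magnitude penalty is applied once at step 180.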

    def reset(self):
        self._joint_angles = self.start_pos
        self._angle_velocity = self.start_vel
        self._joints = np.zeros((self.num_links + 1, 2))
        self._update_joints()
        self._steps = 0

        return self._get_obs().copy()

    def _update_joints(self):
        """
        update _joints to get new end effector position. The other links are only required for rendering.
        Returns:

        """
        line_points_in_taskspace = self.get_forward_kinematics(num_points_per_link=20)

        self._joints[1:, 0] = self._joints[0, 0] + line_points_in_taskspace[:, -1, 0]
        self._joints[1:, 1] = self._joints[0, 1] + line_points_in_taskspace[:, -1, 1]

        self_collision = False
        wall_collision = False

        if not self.allow_self_collision:
            self_collision = self.check_self_collision(line_points_in_taskspace)
            if np.any(np.abs(self._joint_angles) > np.pi) and not self.allow_self_collision:
                self_collision = True

        if not self.allow_wall_collision:
            wall_collision = self.check_wall_collision(line_points_in_taskspace)

        self._is_collided = self_collision or wall_collision
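
    # NOTE (added): collision checking samples 20 points per link from
    # get_forward_kinematics(); self-collision tests segment intersections
    # between non-adjacent links, wall collision tests the sampled points
    # against the ground surface and the hole boundaries.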

    def _get_obs(self):
        theta = self._joint_angles
        return np.hstack([
            np.cos(theta),
            np.sin(theta),
            self._angle_velocity,
            self.end_effector - self.bottom_center_of_hole,
            self._steps
        ])

    def _reward(self):
        dist_reward = 0
        if not self._is_collided:
            if self._steps == 180:
                dist_reward = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
        else:
            dist_reward = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)

        # TODO: make negative
        out = - dist_reward ** 2

        return out

    def get_forward_kinematics(self, num_points_per_link=1):
        theta = self._joint_angles[:, None]

        if num_points_per_link > 1:
            intermediate_points = np.linspace(0, 1, num_points_per_link)
        else:
            intermediate_points = 1

        accumulated_theta = np.cumsum(theta, axis=0)

        endeffector = np.zeros(shape=(self.num_links, num_points_per_link, 2))

        x = np.cos(accumulated_theta) * self.link_lengths * intermediate_points
        y = np.sin(accumulated_theta) * self.link_lengths * intermediate_points

        endeffector[0, :, 0] = x[0, :]
        endeffector[0, :, 1] = y[0, :]

        for i in range(1, self.num_links):
            endeffector[i, :, 0] = x[i, :] + endeffector[i - 1, -1, 0]
            endeffector[i, :, 1] = y[i, :] + endeffector[i - 1, -1, 1]

        return np.squeeze(endeffector + self._joints[0, :])
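
    # Shape note (added): the array is (num_links, num_points_per_link, 2)
    # before np.squeeze, i.e. one (x, y) sample per intermediate point per
    # link; with num_points_per_link=1 it squeezes to (num_links, 2).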

    def check_self_collision(self, line_points):
        for i, line1 in enumerate(line_points):
            for line2 in line_points[i + 2:, :, :]:
                # if line1 != line2:
                if intersect(line1[0], line1[-1], line2[0], line2[-1]):
                    return True
        return False

    def check_wall_collision(self, line_points):

        # all points that are before the hole in x
        r, c = np.where(line_points[:, :, 0] < (self.hole_x - self.hole_width / 2))

        # check if any of those points are below surface
        nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)

        if nr_line_points_below_surface_before_hole > 0:
            return True

        # all points that are after the hole in x
        r, c = np.where(line_points[:, :, 0] > (self.hole_x + self.hole_width / 2))

        # check if any of those points are below surface
        nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)

        if nr_line_points_below_surface_after_hole > 0:
            return True

        # all points that are above the hole
        r, c = np.where((line_points[:, :, 0] > (self.hole_x - self.hole_width / 2)) &
                        (line_points[:, :, 0] < (self.hole_x + self.hole_width / 2)))

        # check if any of those points are below surface
        nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self.hole_depth)

        if nr_line_points_below_surface_in_hole > 0:
            return True

        return False

    def render(self, mode='human'):
        if self.fig is None:
            self.fig = plt.figure()
            plt.ion()
            plt.pause(0.01)
        else:
            plt.figure(self.fig.number)

        if mode == "human":
            plt.cla()
            plt.title(f"Iteration: {self._steps}, distance: {self.end_effector - self.bottom_center_of_hole}")

            # Arm
            plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')

            # Add the patch to the Axes
            [plt.gca().add_patch(rect) for rect in self.patches]

            lim = np.sum(self.link_lengths) + 0.5
            plt.xlim([-lim, lim])
            plt.ylim([-1.1, lim])
            # plt.draw()
            plt.pause(1e-4)  # pushes window to foreground, which is annoying.
            # self.fig.canvas.flush_events()

        elif mode == "partial":
            # NOTE: this branch referenced names that are undefined inside
            # render() (render_partial, t, fig/ax, line_points_in_taskspace,
            # is_collided); the in-scope equivalents are substituted below.
            if self._steps == 0:
                # Add the patch to the Axes
                [plt.gca().add_patch(rect) for rect in self.patches]

                plt.xlim(-self.num_links, self.num_links), plt.ylim(-1, self.num_links)

            if self._steps % 20 == 0 or self._steps == 199 or self._is_collided:
                line_points_in_taskspace = self.get_forward_kinematics(num_points_per_link=20)
                plt.plot(line_points_in_taskspace[:, 0, 0],
                         line_points_in_taskspace[:, 0, 1],
                         line_points_in_taskspace[:, -1, 0],
                         line_points_in_taskspace[:, -1, 1], marker='o', color='k', alpha=self._steps / 200)

                plt.pause(0.01)

        elif mode == "final":
            if self._steps == 0 or self._is_collided:
                # Add the patch to the Axes
                [plt.gca().add_patch(rect) for rect in self.patches]

                plt.xlim(-self.num_links, self.num_links), plt.ylim(-1, self.num_links)
                # Arm
                plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')

                plt.pause(0.01)


if __name__ == '__main__':
    nl = 5
    env = HoleReacher(num_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=0.15,
                      hole_depth=1, hole_x=1)
    env.reset()

    for i in range(200):
        # objective.load_result("/tmp/cma")
        # explore = np.random.multivariate_normal(mean=np.zeros(30), cov=1 * np.eye(30))
        ac = 11 * env.action_space.sample()
        # ac[0] += np.pi/2
        obs, rew, done, info = env.step(ac)
        env.render(mode="human")  # fixed: the original passed mode="render_full", which no branch handles

        print(rew)

        if done:
            break

@@ -6,6 +6,8 @@ import matplotlib.pyplot as plt
 import numpy as np
 from gym import spaces
 from gym.utils import seeding
+from alr_envs.utils.utils import angle_normalize
+
 
 if os.environ.get("DISPLAY", None):
     mpl.use('Qt5Agg')

alr_envs/utils/dmp_async_vec_env.py (new file, +171 lines)
@@ -0,0 +1,171 @@
import gym
from gym.error import (AlreadyPendingCallError, NoAsyncCallError)
from gym.vector.utils import concatenate, create_empty_array
from gym.vector.async_vector_env import AsyncState
import numpy as np
import multiprocessing as mp
from copy import deepcopy
import sys


class DmpAsyncVectorEnv(gym.vector.AsyncVectorEnv):
    def __init__(self, env_fns, n_samples, observation_space=None, action_space=None,
                 shared_memory=True, copy=True, context=None, daemon=True, worker=None):
        super(DmpAsyncVectorEnv, self).__init__(env_fns,
                                                observation_space=observation_space,
                                                action_space=action_space,
                                                shared_memory=shared_memory,
                                                copy=copy,
                                                context=context,
                                                daemon=daemon,
                                                worker=worker)

        # we need to overwrite the number of samples as we may sample more than num_envs
        self.observations = create_empty_array(self.single_observation_space,
                                               n=n_samples,
                                               fn=np.zeros)

    def __call__(self, params):
        return self.rollout(params)

    def rollout_async(self, actions):
        """
        Parameters
        ----------
        actions : iterable of samples from `action_space`
            List of actions.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError('Calling `rollout_async` while waiting '
                                          'for a pending call to `{0}` to complete.'.format(self._state.value),
                                          self._state.value)

        # split_actions = np.array_split(actions, self.num_envs)
        actions = np.atleast_2d(actions)
        split_actions = np.array_split(actions, np.minimum(len(actions), self.num_envs))
        for pipe, action in zip(self.parent_pipes, split_actions):
            pipe.send(('rollout', action))
        for pipe in self.parent_pipes[len(split_actions):]:
            pipe.send(('idle', None))
        self._state = AsyncState.WAITING_ROLLOUT

    def rollout_wait(self, timeout=None):
        """
        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `step_wait` times out. If
            `None`, the call to `step_wait` never times out.

        Returns
        -------
        observations : sample from `observation_space`
            A batch of observations from the vectorized environment.

        rewards : `np.ndarray` instance (dtype `np.float_`)
            A vector of rewards from the vectorized environment.

        dones : `np.ndarray` instance (dtype `np.bool_`)
            A vector whose entries indicate whether the episode has ended.

        infos : list of dict
            A list of auxiliary diagnostic information.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_ROLLOUT:
            raise NoAsyncCallError('Calling `rollout_wait` without any prior call '
                                   'to `rollout_async`.', AsyncState.WAITING_ROLLOUT.value)

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError('The call to `rollout_wait` has timed out after '
                                  '{0} second{1}.'.format(timeout, 's' if timeout > 1 else ''))

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        results = [r for r in results if r is not None]
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        observations_list, rewards, dones, infos = [_flatten_list(r) for r in zip(*results)]

        # if not self.shared_memory:
        #     self.observations = concatenate(observations_list, self.observations,
        #                                     self.single_observation_space)

        # return (deepcopy(self.observations) if self.copy else self.observations,
        #         np.array(rewards), np.array(dones, dtype=np.bool_), infos)

        return np.array(rewards)

    def rollout(self, actions):
        self.rollout_async(actions)
        return self.rollout_wait()
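
    # NOTE (added): unlike AsyncVectorEnv.step, rollout() returns only the
    # rewards, one scalar per parameter vector; the commented-out lines in
    # rollout_wait show the full (obs, rewards, dones, infos) return that was
    # dropped. This suits episodic black-box optimization, where only the
    # return of each sampled parameter set is needed.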


def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    assert shared_memory is None
    env = env_fn()
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == 'reset':
                observation = env.reset()
                pipe.send((observation, True))
            elif command == 'step':
                observation, reward, done, info = env.step(data)
                if done:
                    observation = env.reset()
                pipe.send(((observation, reward, done, info), True))
            elif command == 'rollout':
                observations = []
                rewards = []
                dones = []
                infos = []
                for d in data:
                    env.reset()
                    observation, reward, done, info = env.step(d)
                    observations.append(observation)
                    rewards.append(reward)
                    dones.append(done)
                    infos.append(info)
                pipe.send(((observations, rewards, dones, infos), (True, ) * len(rewards)))
            elif command == 'seed':
                env.seed(data)
                pipe.send((None, True))
            elif command == 'close':
                pipe.send((None, True))
                break
            elif command == 'idle':
                pipe.send((None, True))
            elif command == '_check_observation_space':
                pipe.send((data == env.observation_space, True))
            else:
                raise RuntimeError('Received unknown command `{0}`. Must '
                                   'be one of {`reset`, `step`, `seed`, `close`, '
                                   '`_check_observation_space`}.'.format(command))
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()


def _flatten_obs(obs):
    assert isinstance(obs, (list, tuple))
    assert len(obs) > 0

    if isinstance(obs[0], dict):
        keys = obs[0].keys()
        return {k: np.stack([o[k] for o in obs]) for k in keys}
    else:
        return np.stack(obs)


def _flatten_list(l):
    assert isinstance(l, (list, tuple))
    assert len(l) > 0
    assert all([len(l_) > 0 for l_ in l])

    return [l__ for l_ in l for l__ in l_]

alr_envs/utils/dmp_env_wrapper.py (new file, +102 lines)
@@ -0,0 +1,102 @@
from mp_lib.phase import ExpDecayPhaseGenerator
from mp_lib.basis import DMPBasisGenerator
from mp_lib import dmps
import numpy as np
import gym


class DmpEnvWrapperBase(gym.Wrapper):
    def __init__(self, env, num_dof, num_basis, duration=1, dt=0.01, learn_goal=False):
        super(DmpEnvWrapperBase, self).__init__(env)
        self.num_dof = num_dof
        self.num_basis = num_basis
        self.dim = num_dof * num_basis
        # set the flag unconditionally (the original only set it inside the
        # learn_goal branch, so goal_and_weights() raised an AttributeError
        # whenever learn_goal=False)
        self.learn_goal = learn_goal
        if learn_goal:
            self.dim += num_dof
        time_steps = int(duration / dt)  # duration in seconds
        self.t = np.linspace(0, duration, time_steps)

        phase_generator = ExpDecayPhaseGenerator(alpha_phase=5, duration=duration)
        basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=self.num_basis)

        self.dmp = dmps.DMP(num_dof=num_dof,
                            basis_generator=basis_generator,
                            phase_generator=phase_generator,
                            num_time_steps=time_steps,
                            dt=dt
                            )

        self.dmp.dmp_start_pos = env.start_pos.reshape((1, num_dof))

        dmp_weights = np.zeros((num_basis, num_dof))
        dmp_goal_pos = np.zeros(num_dof)

        self.dmp.set_weights(dmp_weights, dmp_goal_pos)

    def goal_and_weights(self, params):
        if len(params.shape) > 1:
            assert params.shape[1] == self.dim
        else:
            assert len(params) == self.dim
            params = np.reshape(params, [1, self.dim])

        if self.learn_goal:
            goal_pos = params[0, -self.num_dof:]
            weight_matrix = np.reshape(params[:, :-self.num_dof], [self.num_basis, self.num_dof])
        else:
            goal_pos = None
            weight_matrix = np.reshape(params, [self.num_basis, self.num_dof])

        return goal_pos, weight_matrix
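
    # Shape example (added): with num_dof=5, num_basis=5 and learn_goal=True,
    # params has 30 entries; the last 5 become goal_pos and the first 25 are
    # reshaped to a (5, 5) weight_matrix (basis x dof), matching the 25 + 5
    # parameter layout used in dmp_env_wrapper_example.py below.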

    def step(self, action, render=False):
        """ overwrite step function where action now is the weights and possible goal position"""
        raise NotImplementedError


class DmpEnvWrapperAngle(DmpEnvWrapperBase):
    def step(self, action, render=False):
        goal_pos, weight_matrix = self.goal_and_weights(action)
        self.dmp.set_weights(weight_matrix, goal_pos)
        trajectory, velocities = self.dmp.reference_trajectory(self.t)

        rews = []

        for t, traj in enumerate(trajectory):
            obs, rew, done, info = self.env.step(traj)
            rews.append(rew)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)
        done = True
        info = {}

        return obs, reward, done, info


class DmpEnvWrapperVel(DmpEnvWrapperBase):
    def step(self, action, render=False):
        goal_pos, weight_matrix = self.goal_and_weights(action)
        weight_matrix *= 50
        self.dmp.set_weights(weight_matrix, goal_pos)
        trajectory, velocities = self.dmp.reference_trajectory(self.t)

        rews = []

        for t, vel in enumerate(velocities):
            obs, rew, done, info = self.env.step(vel)
            rews.append(rew)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)
        info = {}

        return obs, reward, done, info
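

# NOTE (added): DmpEnvWrapperAngle plays back the DMP positions and always
# reports done=True after one DMP rollout, while DmpEnvWrapperVel plays back
# the DMP velocities (with weights scaled by 50) and passes through the final
# done flag from the wrapped env.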

dmp_env_wrapper_example.py (new file, +77 lines)
@@ -0,0 +1,77 @@
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapperVel
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv, _worker
from alr_envs.classic_control.hole_reacher import HoleReacher
from gym.vector.async_vector_env import AsyncVectorEnv
import numpy as np


# env = gym.make('alr_envs:SimpleReacher-v0')
# env = HoleReacher(num_links=5,
#                   allow_self_collision=False,
#                   allow_wall_collision=True,
#                   hole_width=0.15,
#                   hole_depth=1,
#                   hole_x=1)
#
# env = DmpEnvWrapperVel(env,
#                        num_dof=5,
#                        num_basis=5,
#                        duration=2,
#                        dt=env._dt,
#                        learn_goal=True)
#
# params = np.hstack([50 * np.random.randn(25), np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4])])
#
# print(params)
#
# env.reset()
# obs, rew, done, info = env.step(params, render=True)
#
# print(env.env._joint_angles)
#
# print(rew)

if __name__ == "__main__":

    def make_env(rank, seed=0):
        """
        Utility function for multiprocessed env.

        :param env_id: (str) the environment ID
        :param num_env: (int) the number of environments you wish to have in subprocesses
        :param seed: (int) the initial seed for RNG
        :param rank: (int) index of the subprocess
        """
        def _init():
            env = HoleReacher(num_links=5,
                              allow_self_collision=False,
                              allow_wall_collision=False,
                              hole_width=0.15,
                              hole_depth=1,
                              hole_x=1)

            env = DmpEnvWrapperVel(env,
                                   num_dof=5,
                                   num_basis=5,
                                   duration=2,
                                   dt=env._dt,
                                   learn_goal=True)
            env.seed(seed + rank)
            return env
        return _init

    n_samples = 4

    env = DmpAsyncVectorEnv([make_env(i) for i in range(4)],
                            n_samples=n_samples,
                            context="spawn",
                            shared_memory=False,
                            worker=_worker)

    # params = np.random.randn(4, 25)
    params = np.hstack([50 * np.random.randn(n_samples, 25),
                        np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])

    # env.reset()
    out = env.rollout(params)

    print(out)
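
Run as a script, out is expected to be a length-4 array holding one summed
episode reward per row of params, since rollout_wait returns only the rewards
(a reading of the code above, not captured output).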

@@ -4,8 +4,8 @@ import gym
 
 if __name__ == '__main__':
 
-    env = gym.make('alr_envs:ALRReacher-v0')
-    # env = gym.make('alr_envs:SimpleReacher-v0')
+    # env = gym.make('alr_envs:ALRReacher-v0')
+    env = gym.make('alr_envs:SimpleReacher-v0')
     # env = gym.make('alr_envs:ALRReacher7-v0')
     state = env.reset()
 