diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py
index 8e46fa9..fe32c0b 100644
--- a/alr_envs/__init__.py
+++ b/alr_envs/__init__.py
@@ -1,7 +1,7 @@
 from gym.envs.registration import register
 from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
-# from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
+# from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
 
 # Mujoco
diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py
index 3b382f9..45bc1c3 100644
--- a/alr_envs/classic_control/hole_reacher.py
+++ b/alr_envs/classic_control/hole_reacher.py
@@ -1,27 +1,34 @@
+from typing import Union
+
 import gym
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
+from gym.utils import seeding
 from matplotlib import patches
+
 from alr_envs.classic_control.utils import check_self_collision
+from alr_envs.utils.mps.mp_environments import MPEnv
 
 
-class HoleReacher(gym.Env):
-
-    def __init__(self, n_links, hole_x, hole_width, hole_depth, allow_self_collision=False,
-                 allow_wall_collision=False, collision_penalty=1000):
+class HoleReacher(MPEnv):
+    def __init__(self, n_links, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
+                 hole_width: float = 1., random_start: bool = True, allow_self_collision: bool = False,
+                 allow_wall_collision: bool = False, collision_penalty: float = 1000):
 
         self.n_links = n_links
         self.link_lengths = np.ones((n_links, 1))
 
-        # task
-        self.hole_x = hole_x  # x-position of center of hole
-        self.hole_width = hole_width  # width of hole
-        self.hole_depth = hole_depth  # depth of hole
+        self.random_start = random_start
 
-        self.bottom_center_of_hole = np.hstack([hole_x, -hole_depth])
-        self.top_center_of_hole = np.hstack([hole_x, 0])
-        self.left_wall_edge = np.hstack([hole_x - self.hole_width / 2, 0])
-        self.right_wall_edge = np.hstack([hole_x + self.hole_width / 2, 0])
+        # provided initial parameters
+        self._hole_x = hole_x  # x-position of center of hole
+        self._hole_width = hole_width  # width of hole
+        self._hole_depth = hole_depth  # depth of hole
+
+        # temp containers to store current setting
+        self._tmp_hole_x = None
+        self._tmp_hole_width = None
+        self._tmp_hole_depth = None
 
         # collision
         self.allow_self_collision = allow_self_collision
@@ -29,11 +36,11 @@ class HoleReacher(gym.Env):
         self.collision_penalty = collision_penalty
 
         # state
-        self._joints = None
         self._joint_angles = None
         self._angle_velocity = None
-        self.start_pos = np.hstack([[np.pi / 2], np.zeros(self.n_links - 1)])
-        self.start_vel = np.zeros(self.n_links)
+        self._joints = None
+        self._start_pos = np.hstack([[np.pi / 2], np.zeros(self.n_links - 1)])
+        self._start_vel = np.zeros(self.n_links)
 
         self.dt = 0.01
         # self.time_limit = 2
@@ -43,35 +50,64 @@ class HoleReacher(gym.Env):
             [np.pi] * self.n_links,  # cos
             [np.pi] * self.n_links,  # sin
             [np.inf] * self.n_links,  # velocity
+            [np.inf],  # hole width
+            [np.inf],  # hole depth
             [np.inf] * 2,  # x-y coordinates of target distance
             [np.inf]  # env steps, because reward start after n steps TODO: Maybe
         ])
 
         self.action_space = gym.spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)
         self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
 
+        plt.ion()
         self.fig = None
-        rect_1 = patches.Rectangle((-self.n_links, -1),
-                                   self.n_links + self.hole_x - self.hole_width / 2, 1,
-                                   fill=True, edgecolor='k', facecolor='k')
-        rect_2 = patches.Rectangle((self.hole_x + self.hole_width / 2, -1),
-                                   self.n_links - self.hole_x + self.hole_width / 2, 1,
-                                   fill=True, edgecolor='k', facecolor='k')
-        rect_3 = patches.Rectangle((self.hole_x - self.hole_width / 2, -1), self.hole_width,
-                                   1 - self.hole_depth,
-                                   fill=True, edgecolor='k', facecolor='k')
-        self.patches = [rect_1, rect_2, rect_3]
+        self.seed()
+
+    @property
+    def corrected_obs_index(self):
+        return np.hstack([
+            [self.random_start] * self.n_links,  # cos
+            [self.random_start] * self.n_links,  # sin
+            [self.random_start] * self.n_links,  # velocity
+            [self._hole_width is None],  # hole width
+            [self._hole_depth is None],  # hole depth
+            [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
+    def seed(self, seed=None):
+        self.np_random, seed = seeding.np_random(seed)
+        return [seed]
 
     @property
     def end_effector(self):
         return self._joints[self.n_links].T
 
-    def configure(self, context):
-        pass
+    def _generate_hole(self):
+        hole_x = self.np_random.uniform(0.5, 3.5, 1) if self._hole_x is None else np.copy(self._hole_x)
+        hole_width = self.np_random.uniform(0.5, 0.1, 1) if self._hole_width is None else np.copy(self._hole_width)
+        # TODO we do not want this right now.
+        hole_depth = self.np_random.uniform(1, 1, 1) if self._hole_depth is None else np.copy(self._hole_depth)
+
+        self.bottom_center_of_hole = np.hstack([hole_x, -hole_depth])
+        self.top_center_of_hole = np.hstack([hole_x, 0])
+        self.left_wall_edge = np.hstack([hole_x - hole_width / 2, 0])
+        self.right_wall_edge = np.hstack([hole_x + hole_width / 2, 0])
+
+        return hole_x, hole_width, hole_depth
 
     def reset(self):
-        self._joint_angles = self.start_pos
-        self._angle_velocity = self.start_vel
+        if self.random_start:
+            # Maybe change more than the first joint
+            first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
+            self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
+        else:
+            self._joint_angles = self._start_pos
+
+        self._tmp_hole_x, self._tmp_hole_width, self._tmp_hole_depth = self._generate_hole()
+        self.set_patches()
+
+        self._angle_velocity = self._start_vel
         self._joints = np.zeros((self.n_links + 1, 2))
         self._update_joints()
         self._steps = 0
@@ -96,15 +132,14 @@ class HoleReacher(gym.Env):
         success = False
         reward = 0
         if not self._is_collided:
+            # return reward only in last time step
            if self._steps == 199:
                 dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
                 reward = - dist ** 2
                 success = dist < 0.005
         else:
+            # Episode terminates when colliding, hence return reward
             dist = np.linalg.norm(self.end_effector - self.bottom_center_of_hole)
-            # if self.collision_penalty != 0:
-            #     reward = -self.collision_penalty
-            # else:
             reward = - dist ** 2 - self.collision_penalty
 
         reward -= 5e-8 * np.sum(acc ** 2)
@@ -112,8 +147,6 @@ class HoleReacher(gym.Env):
         info = {"is_collided": self._is_collided, "is_success": success}
 
         self._steps += 1
-
-        # done = self._steps * self.dt > self.time_limit or self._is_collided
         done = self._is_collided
 
         return self._get_obs().copy(), reward, done, info
@@ -148,6 +181,8 @@ class HoleReacher(gym.Env):
             np.cos(theta),
             np.sin(theta),
             self._angle_velocity,
+            self._tmp_hole_width,
+            self._tmp_hole_depth,
             self.end_effector - self.bottom_center_of_hole,
             self._steps
         ])
@@ -155,31 +190,26 @@ class HoleReacher(gym.Env):
     def get_forward_kinematics(self, num_points_per_link=1):
         theta = self._joint_angles[:, None]
 
-        if num_points_per_link > 1:
-            intermediate_points = np.linspace(0, 1, num_points_per_link)
-        else:
-            intermediate_points = 1
-
+        intermediate_points = np.linspace(0, 1, num_points_per_link) if num_points_per_link > 1 else 1
         accumulated_theta = np.cumsum(theta, axis=0)
-
-        endeffector = np.zeros(shape=(self.n_links, num_points_per_link, 2))
+        end_effector = np.zeros(shape=(self.n_links, num_points_per_link, 2))
 
         x = np.cos(accumulated_theta) * self.link_lengths * intermediate_points
         y = np.sin(accumulated_theta) * self.link_lengths * intermediate_points
 
-        endeffector[0, :, 0] = x[0, :]
-        endeffector[0, :, 1] = y[0, :]
+        end_effector[0, :, 0] = x[0, :]
+        end_effector[0, :, 1] = y[0, :]
 
         for i in range(1, self.n_links):
-            endeffector[i, :, 0] = x[i, :] + endeffector[i - 1, -1, 0]
-            endeffector[i, :, 1] = y[i, :] + endeffector[i - 1, -1, 1]
+            end_effector[i, :, 0] = x[i, :] + end_effector[i - 1, -1, 0]
+            end_effector[i, :, 1] = y[i, :] + end_effector[i - 1, -1, 1]
 
-        return np.squeeze(endeffector + self._joints[0, :])
+        return np.squeeze(end_effector + self._joints[0, :])
 
     def check_wall_collision(self, line_points):
 
         # all points that are before the hole in x
-        r, c = np.where(line_points[:, :, 0] < (self.hole_x - self.hole_width / 2))
+        r, c = np.where(line_points[:, :, 0] < (self._tmp_hole_x - self._tmp_hole_width / 2))
 
         # check if any of those points are below surface
         nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
@@ -188,7 +218,7 @@
             return True
 
         # all points that are after the hole in x
-        r, c = np.where(line_points[:, :, 0] > (self.hole_x + self.hole_width / 2))
+        r, c = np.where(line_points[:, :, 0] > (self._tmp_hole_x + self._tmp_hole_width / 2))
 
         # check if any of those points are below surface
         nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
@@ -197,11 +227,11 @@
             return True
 
         # all points that are above the hole
-        r, c = np.where((line_points[:, :, 0] > (self.hole_x - self.hole_width / 2)) & (
-                line_points[:, :, 0] < (self.hole_x + self.hole_width / 2)))
+        r, c = np.where((line_points[:, :, 0] > (self._tmp_hole_x - self._tmp_hole_width / 2)) & (
+                line_points[:, :, 0] < (self._tmp_hole_x + self._tmp_hole_width / 2)))
 
         # check if any of those points are below surface
-        nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self.hole_depth)
+        nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_hole_depth)
 
         if nr_line_points_below_surface_in_hole > 0:
             return True
@@ -210,28 +240,33 @@
     def render(self, mode='human'):
         if self.fig is None:
+            plt.ion()
             self.fig = plt.figure()
-            # plt.ion()
-            # plt.pause(0.01)
-        else:
-            plt.figure(self.fig.number)
+            ax = self.fig.add_subplot(1, 1, 1)
+
+            # limits
+            lim = np.sum(self.link_lengths) + 0.5
+            ax.set_xlim([-lim, lim])
+            ax.set_ylim([-1.1, lim])
+
+            self.line, = ax.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')
+            self.set_patches()
+            self.fig.show()
 
         if mode == "human":
-            plt.cla()
-            plt.title(f"Iteration: {self._steps}, distance: {self.end_effector - self.bottom_center_of_hole}")
+            self.fig.gca().set_title(
+                f"Iteration: {self._steps}, distance: {self.end_effector - self.bottom_center_of_hole}")
 
             # Arm
             plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k')
 
-            # Add the patch to the Axes
-            [plt.gca().add_patch(rect) for rect in self.patches]
+            # Arm
+            self.line.set_xdata(self._joints[:, 0])
+            self.line.set_ydata(self._joints[:, 1])
 
-            lim = np.sum(self.link_lengths) + 0.5
-            plt.xlim([-lim, lim])
-            plt.ylim([-1.1, lim])
-            # plt.draw()
-            plt.pause(1e-4)  # pushes window to foreground, which is annoying.
-            # self.fig.canvas.flush_events()
+            self.fig.canvas.draw()
+            self.fig.canvas.flush_events()
+            # self.fig.show()
 
         elif mode == "partial":
             if self._steps == 1:
@@ -266,6 +301,24 @@
 
                 plt.pause(0.01)
 
+    def set_patches(self):
+        if self.fig is not None:
+            self.fig.gca().patches = []
+            rect_1 = patches.Rectangle((-self.n_links, -1), self.n_links + self._tmp_hole_x - self._tmp_hole_width / 2,
+                                       1,
+                                       fill=True, edgecolor='k', facecolor='k')
+            rect_2 = patches.Rectangle((self._tmp_hole_x + self._tmp_hole_width / 2, -1),
+                                       self.n_links - self._tmp_hole_x + self._tmp_hole_width / 2, 1,
+                                       fill=True, edgecolor='k', facecolor='k')
+            rect_3 = patches.Rectangle((self._tmp_hole_x - self._tmp_hole_width / 2, -1), self._tmp_hole_width,
+                                       1 - self._tmp_hole_depth,
+                                       fill=True, edgecolor='k', facecolor='k')
+
+            # Add the patch to the Axes
+            self.fig.gca().add_patch(rect_1)
+            self.fig.gca().add_patch(rect_2)
+            self.fig.gca().add_patch(rect_3)
+
     def close(self):
         if self.fig is not None:
             plt.close(self.fig)
@@ -274,8 +327,8 @@ if __name__ == '__main__':
     nl = 5
     render_mode = "human"  # "human" or "partial" or "final"
-    env = HoleReacher(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=0.15,
-                      hole_depth=1, hole_x=1)
+    env = HoleReacher(n_links=nl, allow_self_collision=False, allow_wall_collision=False, hole_width=None,
+                      hole_depth=1, hole_x=None)
     env.reset()
     # env.render(mode=render_mode)
@@ -285,11 +338,13 @@ if __name__ == '__main__':
         ac = 2 * env.action_space.sample()
         # ac[0] += np.pi/2
         obs, rew, d, info = env.step(ac)
-        env.render(mode=render_mode)
+        # if i % 1 == 0:
+        if i == 0:
+            env.render(mode=render_mode)
 
         print(rew)
 
         if d:
-            break
+            env.reset()
 
     env.close()
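Editor's note (not part of the patch): a rough usage sketch of the reworked, contextual HoleReacher, based only on what the diff above shows. Leaving hole_x and hole_width as None lets reset() sample a new hole per episode via _generate_hole(); the episode count of 200 is only illustrative.

import numpy as np
from alr_envs.classic_control.hole_reacher import HoleReacher

# Contextual setup: hole position and width are resampled on every reset().
env = HoleReacher(n_links=5, hole_x=None, hole_width=None, hole_depth=1,
                  random_start=True, allow_self_collision=False, allow_wall_collision=False)
env.seed(0)

obs = env.reset()
for i in range(200):
    # Random accelerations, mirroring the __main__ block of the patched file.
    obs, reward, done, info = env.step(2 * env.action_space.sample())
    if done:  # episodes now terminate early on collision
        obs = env.reset()
env.close()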
diff --git a/alr_envs/mujoco/ball_in_a_cup/utils.py b/alr_envs/mujoco/ball_in_a_cup/utils.py
index bfec3cf..d368e45 100644
--- a/alr_envs/mujoco/ball_in_a_cup/utils.py
+++ b/alr_envs/mujoco/ball_in_a_cup/utils.py
@@ -1,5 +1,5 @@
-from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
-from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
+from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
+from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
 from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
diff --git a/alr_envs/mujoco/beerpong/utils.py b/alr_envs/mujoco/beerpong/utils.py
index bfbc2a1..cdcbd13 100644
--- a/alr_envs/mujoco/beerpong/utils.py
+++ b/alr_envs/mujoco/beerpong/utils.py
@@ -1,4 +1,4 @@
-from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
+from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
 from alr_envs.mujoco.beerpong.beerpong import ALRBeerpongEnv
 from alr_envs.mujoco.beerpong.beerpong_simple import ALRBeerpongEnv as ALRBeerpongEnvSimple
diff --git a/alr_envs/utils/legacy/utils.py b/alr_envs/utils/legacy/utils.py
index c158cae..fbcb34a 100644
--- a/alr_envs/utils/legacy/utils.py
+++ b/alr_envs/utils/legacy/utils.py
@@ -1,7 +1,7 @@
 import alr_envs.classic_control.hole_reacher as hr
 import alr_envs.classic_control.viapoint_reacher as vpr
-from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
-from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
+from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
+from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
 import numpy as np
 
 
@@ -65,7 +65,7 @@ def make_holereacher_env(rank, seed=0):
                          dt=_env.dt,
                          learn_goal=True,
                          alpha_phase=2,
-                         start_pos=_env.start_pos,
+                         start_pos=_env._start_pos,
                          policy_type="velocity",
                          weights_scale=50,
                          goal_scale=0.1
@@ -105,7 +105,7 @@ def make_holereacher_fix_goal_env(rank, seed=0):
                          learn_goal=False,
                          final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]),
                          alpha_phase=2,
-                         start_pos=_env.start_pos,
+                         start_pos=_env._start_pos,
                          policy_type="velocity",
                          weights_scale=50,
                          goal_scale=1
@@ -142,7 +142,7 @@ def make_holereacher_env_pmp(rank, seed=0):
                             num_basis=5,
                             width=0.02,
                             policy_type="velocity",
-                            start_pos=_env.start_pos,
+                            start_pos=_env._start_pos,
                             duration=2,
                             post_traj_time=0,
                             dt=_env.dt,
diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py
index c0e55b4..d455496 100644
--- a/alr_envs/utils/make_env_helpers.py
+++ b/alr_envs/utils/make_env_helpers.py
@@ -1,5 +1,5 @@
-from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
-from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
+from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
+from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
 import gym
 from gym.vector.utils import write_to_shared_memory
 import sys
diff --git a/alr_envs/utils/wrapper/__init__.py b/alr_envs/utils/mps/__init__.py
similarity index 100%
rename from alr_envs/utils/wrapper/__init__.py
rename to alr_envs/utils/mps/__init__.py
diff --git a/alr_envs/utils/wrapper/detpmp_wrapper.py b/alr_envs/utils/mps/detpmp_wrapper.py
similarity index 66%
rename from alr_envs/utils/wrapper/detpmp_wrapper.py
rename to alr_envs/utils/mps/detpmp_wrapper.py
index 62b93d5..63de98b 100644
--- a/alr_envs/utils/wrapper/detpmp_wrapper.py
+++ b/alr_envs/utils/mps/detpmp_wrapper.py
@@ -2,17 +2,18 @@ import gym
 import numpy as np
 from mp_lib import det_promp
 
-from alr_envs.utils.wrapper.mp_wrapper import MPWrapper
+from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_wrapper import MPWrapper
 
 
 class DetPMPWrapper(MPWrapper):
-    def __init__(self, env, num_dof, num_basis, width, start_pos=None, duration=1, dt=0.01, post_traj_time=0.,
-                 policy_type=None, weights_scale=1, zero_start=False, zero_goal=False, **mp_kwargs):
+    def __init__(self, env: MPEnv, num_dof: int, num_basis: int, width: int, start_pos=None, duration: int = 1,
+                 dt: float = 0.01, post_traj_time: float = 0., policy_type: str = None, weights_scale: float = 1.,
+                 zero_start: bool = False, zero_goal: bool = False, **mp_kwargs):
         # self.duration = duration  # seconds
-        super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
-                         num_basis=num_basis, width=width, start_pos=start_pos, zero_start=zero_start,
-                         zero_goal=zero_goal)
+        super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, num_basis=num_basis,
+                         width=width, start_pos=start_pos, zero_start=zero_start, zero_goal=zero_goal, **mp_kwargs)
 
         action_bounds = np.inf * np.ones((self.mp.n_basis * self.mp.n_dof))
         self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
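Editor's note (summary, not part of the patch): for downstream code, the package rename boils down to the following import migration, all of it visible in the hunks above and below.

# Old import locations (pre-rename):
# from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
# from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
# from alr_envs.utils.wrapper.mp_wrapper import MPWrapper

# New import locations after the wrapper -> mps move:
from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
from alr_envs.utils.mps.mp_wrapper import MPWrapper
from alr_envs.utils.mps.mp_environments import MPEnv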
diff --git a/alr_envs/utils/wrapper/dmp_wrapper.py b/alr_envs/utils/mps/dmp_wrapper.py
similarity index 53%
rename from alr_envs/utils/wrapper/dmp_wrapper.py
rename to alr_envs/utils/mps/dmp_wrapper.py
index 2a198db..e42205e 100644
--- a/alr_envs/utils/wrapper/dmp_wrapper.py
+++ b/alr_envs/utils/mps/dmp_wrapper.py
@@ -1,19 +1,18 @@
-from mp_lib.phase import ExpDecayPhaseGenerator
-from mp_lib.basis import DMPBasisGenerator
-from mp_lib import dmps
-import numpy as np
 import gym
+import numpy as np
+from mp_lib import dmps
+from mp_lib.basis import DMPBasisGenerator
+from mp_lib.phase import ExpDecayPhaseGenerator
 
-from alr_envs.utils.wrapper.mp_wrapper import MPWrapper
+from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_wrapper import MPWrapper
 
 
 class DmpWrapper(MPWrapper):
 
-    def __init__(self, env: gym.Env, num_dof: int, num_basis: int,
-                 # start_pos: np.ndarray = None,
-                 # final_pos: np.ndarray = None,
+    def __init__(self, env: MPEnv, num_dof: int, num_basis: int,
                  duration: int = 1, alpha_phase: float = 2., dt: float = None,
-                 learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0.,
+                 learn_goal: bool = False, post_traj_time: float = 0.,
                  weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
                  policy_type: str = None, render_mode: str = None):
@@ -23,8 +22,6 @@ class DmpWrapper(MPWrapper):
             env:
             num_dof:
             num_basis:
-            start_pos:
-            final_pos:
             duration:
             alpha_phase:
             dt:
@@ -37,30 +34,17 @@ class DmpWrapper(MPWrapper):
         self.learn_goal = learn_goal
         dt = env.dt if hasattr(env, "dt") else dt
         assert dt is not None
-        # start_pos = start_pos if start_pos is not None else env.start_pos if hasattr(env, "start_pos") else None
-        # TODO: assert start_pos is not None  # start_pos will be set in initialize, do we need this here?
-        # if learn_goal:
-        #     final_pos = np.zeros_like(start_pos)  # arbitrary, will be learned
-        # final_pos = np.zeros((1, num_dof))  # arbitrary, will be learned
-        # else:
-        #     final_pos = final_pos if final_pos is not None else start_pos if return_to_start else None
-        #     assert final_pos is not None
         self.t = np.linspace(0, duration, int(duration / dt))
         self.goal_scale = goal_scale
 
-        super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, render_mode,
-                         num_basis=num_basis,
-                         # start_pos=start_pos, final_pos=final_pos,
-                         alpha_phase=alpha_phase,
-                         bandwidth_factor=bandwidth_factor)
+        super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, render_mode,
+                         num_basis=num_basis, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor)
 
         action_bounds = np.inf * np.ones((np.prod(self.mp.dmp_weights.shape) + (num_dof if learn_goal else 0)))
         self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
 
-    def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5,
-                      # start_pos: np.ndarray = None,
-                      # final_pos: np.ndarray = None,
-                      alpha_phase: float = 2., bandwidth_factor: float = 3.):
+    def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, alpha_phase: float = 2.,
+                      bandwidth_factor: int = 3):
 
         phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
         basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis,
@@ -69,15 +53,6 @@ class DmpWrapper(MPWrapper):
         dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
                        num_time_steps=int(duration / dt), dt=dt)
 
-        # dmp.dmp_start_pos = start_pos.reshape((1, num_dof))
-        # in a contextual environment, the start_pos may be not fixed, set in mp_rollout?
-        # TODO: Should we set start_pos in init at all? It's only used after calling rollout anyway...
-        # dmp.dmp_start_pos = start_pos.reshape((1, num_dof)) if start_pos is not None else np.zeros((1, num_dof))
-
-        # weights = np.zeros((num_basis, num_dof))
-        # goal_pos = np.zeros(num_dof) if self.learn_goal else final_pos
-
-        # dmp.set_weights(weights, goal_pos)
         return dmp
 
     def goal_and_weights(self, params):
@@ -87,18 +62,15 @@ class DmpWrapper(MPWrapper):
         if self.learn_goal:
             goal_pos = params[0, -self.mp.num_dimensions:]  # [num_dof]
             params = params[:, :-self.mp.num_dimensions]  # [1,num_dof]
-            # weight_matrix = np.reshape(params[:, :-self.num_dof], [self.num_basis, self.num_dof])
         else:
             goal_pos = self.env.goal_pos  # self.mp.dmp_goal_pos.flatten()
             assert goal_pos is not None
-            # weight_matrix = np.reshape(params, [self.num_basis, self.num_dof])
 
-        weight_matrix = np.reshape(params, self.mp.dmp_weights.shape)
+        weight_matrix = np.reshape(params, self.mp.dmp_weights.shape)  # [num_basis, num_dof]
 
         return goal_pos * self.goal_scale, weight_matrix * self.weights_scale
 
     def mp_rollout(self, action):
-        # if self.mp.start_pos is None:
-        self.mp.dmp_start_pos = self.env.init_qpos  # start_pos
+        self.mp.dmp_start_pos = self.env.start_pos
         goal_pos, weight_matrix = self.goal_and_weights(action)
         self.mp.set_weights(weight_matrix, goal_pos)
         return self.mp.reference_trajectory(self.t)
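Editor's note (illustration, not part of the patch): how a flat DmpWrapper action vector is split by goal_and_weights(), assuming self.mp.dmp_weights has shape (num_basis, num_dof) as implied by the action_bounds computation above; the sizes below are made up.

import numpy as np

num_dof, num_basis = 5, 5
learn_goal = True

# Flat parameter vector: num_basis * num_dof DMP weights, plus one goal entry
# per DoF when learn_goal=True.
action_dim = num_basis * num_dof + (num_dof if learn_goal else 0)
params = np.atleast_2d(np.random.randn(action_dim))

# goal_and_weights() splits (and then scales) the vector the same way:
goal_pos = params[0, -num_dof:] if learn_goal else None
weights = params[:, :-num_dof] if learn_goal else params
weight_matrix = np.reshape(weights, (num_basis, num_dof))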
diff --git a/alr_envs/utils/mps/mp_environments.py b/alr_envs/utils/mps/mp_environments.py
new file mode 100644
index 0000000..f720f2f
--- /dev/null
+++ b/alr_envs/utils/mps/mp_environments.py
@@ -0,0 +1,33 @@
+from abc import abstractmethod
+from typing import Union
+
+import gym
+import numpy as np
+
+
+class MPEnv(gym.Env):
+
+    @property
+    @abstractmethod
+    def corrected_obs_index(self):
+        """Returns a boolean value for each observation entry,
+        indicating whether that entry is returned by the MP in the contextual case or not.
+        This effectively allows filtering unwanted or unnecessary observations from the full step-based case.
+        """
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def start_pos(self) -> Union[float, int, np.ndarray]:
+        """
+        Returns the current position of the joints
+        """
+        raise NotImplementedError()
+
+    @property
+    def goal_pos(self) -> Union[float, int, np.ndarray]:
+        """
+        Returns the current final position of the joints for the MP.
+        By default this returns the starting position.
+        """
+        return self.start_pos
diff --git a/alr_envs/utils/wrapper/mp_wrapper.py b/alr_envs/utils/mps/mp_wrapper.py
similarity index 68%
rename from alr_envs/utils/wrapper/mp_wrapper.py
rename to alr_envs/utils/mps/mp_wrapper.py
index adeba55..621de00 100644
--- a/alr_envs/utils/wrapper/mp_wrapper.py
+++ b/alr_envs/utils/mps/mp_wrapper.py
@@ -1,32 +1,18 @@
 from abc import ABC, abstractmethod
-from collections import defaultdict
 
 import gym
 import numpy as np
 
+from alr_envs.utils.mps.mp_environments import MPEnv
 from alr_envs.utils.policies import get_policy_class
 
 
 class MPWrapper(gym.Wrapper, ABC):
 
-    def __init__(self,
-                 env: gym.Env,
-                 num_dof: int,
-                 duration: int = 1,
-                 dt: float = None,
-                 post_traj_time: float = 0.,
-                 policy_type: str = None,
-                 weights_scale: float = 1.,
-                 render_mode: str = None,
-                 **mp_kwargs
-                 ):
+    def __init__(self, env: MPEnv, num_dof: int, dt: float, duration: int = 1, post_traj_time: float = 0.,
+                 policy_type: str = None, weights_scale: float = 1., render_mode: str = None, **mp_kwargs):
         super().__init__(env)
 
-        # self.num_dof = num_dof
-        # self.num_basis = num_basis
-        # self.duration = duration  # seconds
-
-        # dt = env.dt if hasattr(env, "dt") else dt
         assert dt is not None  # this should never happen as MPWrapper is a base class
         self.post_traj_steps = int(post_traj_time / dt)
@@ -40,8 +26,11 @@ class MPWrapper(gym.Wrapper, ABC):
         self.render_mode = render_mode
         self.render_kwargs = {}
 
-    # TODO: not yet final
+    # TODO: @Max I think this should not be in this class, this functionality should be part of your sampler.
     def __call__(self, params, contexts=None):
+        """
+        Can be used to provide a batch of parameter sets
+        """
         params = np.atleast_2d(params)
         obs = []
         rewards = []
@@ -63,7 +52,7 @@ class MPWrapper(gym.Wrapper, ABC):
 
     def reset(self):
         obs = self.env.reset()
-        return obs
+        return obs[self.env.corrected_obs_index]
 
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
@@ -77,15 +66,9 @@ class MPWrapper(gym.Wrapper, ABC):
         #     self._velocity = velocity
 
         rewards = 0
-        # infos = defaultdict(list)
-
-        # TODO: @Max Why do we need this configure, states should be part of the model
-        # TODO: Ask Onur if the context distribution needs to be outside the environment
-        # TODO: For now create a new env with each context
-        # TODO: Explicitly call reset before step to obtain context from obs?
-        # self.env.configure(context)
-        # obs = self.env.reset()
         info = {}
+        # create random obs as the reset function is called externally
+        obs = self.env.observation_space.sample()
 
         for t, pos_vel in enumerate(zip(trajectory, velocity)):
             ac = self.policy.get_action(pos_vel[0], pos_vel[1])
@@ -107,18 +90,6 @@ class MPWrapper(gym.Wrapper, ABC):
         self.render_mode = mode
         self.render_kwargs = kwargs
 
-    # def __call__(self, actions):
-    #     return self.step(actions)
-    #     params = np.atleast_2d(params)
-    #     rewards = []
-    #     infos = []
-    #     for p, c in zip(params, contexts):
-    #         reward, info = self.rollout(p, c)
-    #         rewards.append(reward)
-    #         infos.append(info)
-    #
-    #     return np.array(rewards), infos
-
     @abstractmethod
     def mp_rollout(self, action):
        """
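Editor's note (illustrative sketch, not part of the patch): a hypothetical minimal subclass showing what the new MPEnv interface expects from an environment before it can be wrapped; the class, its sizes, and its reward are made up for illustration. It provides a corrected_obs_index mask, a start_pos property (used by DmpWrapper.mp_rollout to set dmp_start_pos), and inherits goal_pos, which defaults to start_pos.

import numpy as np
from alr_envs.utils.mps.mp_environments import MPEnv


class ToyMPEnv(MPEnv):
    """Toy 2-link example of the MPEnv interface (sketch only)."""

    def __init__(self):
        self.n_links = 2
        self._joint_angles = np.zeros(self.n_links)
        self._goal = np.array([0.5, -0.5])

    @property
    def corrected_obs_index(self):
        # Boolean mask over the step-based observation: the contextual (MP) view
        # keeps only the goal coordinates, not the joint state.
        return np.hstack([[False] * self.n_links, [True] * 2])

    @property
    def start_pos(self):
        # Consumed by DmpWrapper.mp_rollout via self.env.start_pos.
        return self._joint_angles

    # goal_pos is inherited from MPEnv and falls back to start_pos.

    def reset(self):
        self._joint_angles = np.zeros(self.n_links)
        return np.hstack([self._joint_angles, self._goal])

    def step(self, action):
        self._joint_angles = self._joint_angles + action
        reward = -np.linalg.norm(self._joint_angles - self._goal)
        return np.hstack([self._joint_angles, self._goal]), reward, False, {}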