From 1f5a7b67f52a1842ec7fdfc612cc7de0f0b6e1c9 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 25 Jan 2022 15:23:57 +0100 Subject: [PATCH 1/4] Added ALRReacherProMP --- alr_envs/alr/__init__.py | 54 ++++++++++++++++++++- alr_envs/alr/mujoco/reacher/__init__.py | 1 + alr_envs/alr/mujoco/reacher/alr_reacher.py | 32 +++++++++--- alr_envs/alr/mujoco/reacher/mp_wrapper.py | 31 ++++++++++++ alr_envs/examples/pd_control_gain_tuning.py | 50 ++++++++++++------- 5 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 alr_envs/alr/mujoco/reacher/mp_wrapper.py diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 90ec78c..53f292c 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv from .mujoco.reacher.alr_reacher import ALRReacherEnv from .mujoco.reacher.balancing import BalancingEnv -from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS +from .mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} @@ -363,6 +363,58 @@ for _v in _versions: } ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) + +## ALRReacher +_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"] +for _v in _versions: + _name = _v.split("-") + _env_id = f'{_name[0]}DMP-{_name[1]}' + register( + id=_env_id, + entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', + # max_episode_steps=1, + kwargs={ + "name": f"alr_envs:{_v}", + "wrappers": [mujoco.reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5 if "long" not in _v.lower() else 7, + "num_basis": 5, + "duration": 4, + "alpha_phase": 2, + "learn_goal": True, + "policy_type": "motor", + "weights_scale": 1, + "policy_kwargs": { + "p_gains": 1, + "d_gains": 0.1 + } + } + } + ) + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) + + _env_id = f'{_name[0]}ProMP-{_name[1]}' + register( + id=_env_id, + entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', + kwargs={ + "name": f"alr_envs:{_v}", + "wrappers": [mujoco.reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5 if "long" not in _v.lower() else 7, + "num_basis": 5, + "duration": 4, + "policy_type": "motor", + "weights_scale": 1, + "zero_start": True, + "policy_kwargs": { + "p_gains": 1, + "d_gains": 0.1 + } + } + } + ) + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ## Beerpong _versions = ["v0", "v1", "v2", "v3"] diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index 2d122d2..e21eaed 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -42,7 +42,10 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): if self._steps >= self.steps_before_reward: vec = self.get_body_com("fingertip") - self.get_body_com("target") reward_dist -= self.reward_weight * np.linalg.norm(vec) - angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) + if self.steps_before_reward > 0: + # avoid giving this penalty for normal step based case + angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) + # angular_vel -= 
np.square(self.sim.data.qvel.flat[:self.n_links]).sum() reward_ctrl = - np.square(a).sum() if self.balance: @@ -61,14 +64,29 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): def viewer_setup(self): self.viewer.cam.trackbodyid = 0 + # def reset_model(self): + # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos + # while True: + # self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + # if np.linalg.norm(self.goal) < self.n_links / 10: + # break + # qpos[-2:] = self.goal + # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + # qvel[-2:] = 0 + # self.set_state(qpos, qvel) + # self._steps = 0 + # + # return self._get_obs() + def reset_model(self): - qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos - while True: - self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) - if np.linalg.norm(self.goal) < self.n_links / 10: - break + qpos = self.init_qpos + if not hasattr(self, "goal"): + while True: + self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + if np.linalg.norm(self.goal) < self.n_links / 10: + break qpos[-2:] = self.goal - qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + qvel = self.init_qvel qvel[-2:] = 0 self.set_state(qpos, qvel) self._steps = 0 diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py new file mode 100644 index 0000000..027a2e2 --- /dev/null +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -0,0 +1,31 @@ +from typing import Union + +import numpy as np +from mp_env_api import MPEnvWrapper + + +class MPWrapper(MPEnvWrapper): + + @property + def active_obs(self): + return np.concatenate([ + [True] * self.n_links, # cos + [True] * self.n_links, # sin + [True] * 2, # goal position + [True] * self.n_links, # angular velocity + [True] * 3, # goal distance + # self.get_body_com("target"), # only return target to make problem harder + [False], # step + ]) + + @property + def current_vel(self) -> Union[float, int, np.ndarray]: + return self.sim.data.qvel.flat[:self.n_links] + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + return self.sim.data.qpos.flat[:self.n_links] + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index 90aac11..bdcaa41 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -2,36 +2,46 @@ import numpy as np from matplotlib import pyplot as plt from alr_envs import dmc, meta +from alr_envs.alr import mujoco from alr_envs.utils.make_env_helpers import make_promp_env + +def visualize(env): + t = env.t + pos_features = env.mp.basis_generator.basis(t) + plt.plot(t, pos_features) + plt.show() + # This might work for some environments, however, please verify either way the correct trajectory information # for your environment are extracted below -SEED = 10 -env_id = "ball_in_cup-catch" -wrappers = [dmc.ball_in_cup.MPWrapper] +SEED = 1 +# env_id = "ball_in_cup-catch" +env_id = "ALRReacherSparse-v0" +wrappers = [mujoco.reacher.MPWrapper] mp_kwargs = { - "num_dof": 2, - "num_basis": 10, - "duration": 2, - "width": 0.025, + "num_dof": 5, + "num_basis": 8, + "duration": 4, "policy_type": "motor", "weights_scale": 1, "zero_start": True, "policy_kwargs": { 
"p_gains": 1, - "d_gains": 1 + "d_gains": 0.1 } } -kwargs = dict(time_limit=2, episode_length=100) +# kwargs = dict(time_limit=4, episode_length=200) +kwargs = {} -env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, - **kwargs) +env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs) # Plot difference between real trajectory and target MP trajectory env.reset() -pos, vel = env.mp_rollout(env.action_space.sample()) +w = env.action_space.sample() * 10 +visualize(env) +pos, vel = env.mp_rollout(w) base_shape = env.full_action_space.shape actual_pos = np.zeros((len(pos), *base_shape)) @@ -51,18 +61,22 @@ plt.figure(figsize=(15, 5)) plt.subplot(131) plt.title("Position") -plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))]) +p1 = plt.plot(actual_pos, c='C0', label="true") # plt.plot(actual_pos_ball, label="true pos ball") -plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +p2 = plt.plot(pos, c='C1', label="MP") # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))]) plt.xlabel("Episode steps") -plt.legend() +# plt.legend() +handles, labels = plt.gca().get_legend_handles_labels() +from collections import OrderedDict + +by_label = OrderedDict(zip(labels, handles)) +plt.legend(by_label.values(), by_label.keys()) plt.subplot(132) plt.title("Velocity") -plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))]) -plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +plt.plot(actual_vel, c='C0', label="true") +plt.plot(vel, c='C1', label="MP") plt.xlabel("Episode steps") -plt.legend() plt.subplot(133) plt.title("Actions") From d313795cec6f18a3c6e694b54c7863efd2c92ae2 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 7 Apr 2022 14:40:43 +0200 Subject: [PATCH 2/4] reacher adjustments --- alr_envs/alr/__init__.py | 12 +- alr_envs/alr/mujoco/reacher/alr_reacher.py | 60 ++++++---- .../mujoco/reacher/assets/reacher_5links.xml | 107 +++++++++--------- alr_envs/alr/mujoco/reacher/mp_wrapper.py | 22 +++- alr_envs/examples/pd_control_gain_tuning.py | 38 +++++-- 5 files changed, 141 insertions(+), 98 deletions(-) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 53f292c..8a7140d 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -357,13 +357,13 @@ for _v in _versions: "num_basis": 5, "duration": 2, "policy_type": "velocity", - "weights_scale": 0.1, + "weights_scale": 5, "zero_start": True } } ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) - + ## ALRReacher _versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"] for _v in _versions: @@ -378,12 +378,12 @@ for _v in _versions: "wrappers": [mujoco.reacher.MPWrapper], "mp_kwargs": { "num_dof": 5 if "long" not in _v.lower() else 7, - "num_basis": 5, + "num_basis": 2, "duration": 4, "alpha_phase": 2, "learn_goal": True, "policy_type": "motor", - "weights_scale": 1, + "weights_scale": 5, "policy_kwargs": { "p_gains": 1, "d_gains": 0.1 @@ -402,10 +402,10 @@ for _v in _versions: "wrappers": [mujoco.reacher.MPWrapper], "mp_kwargs": { "num_dof": 5 if "long" not in _v.lower() else 7, - "num_basis": 5, + "num_basis": 1, "duration": 4, "policy_type": "motor", - "weights_scale": 1, + "weights_scale": 5, "zero_start": True, "policy_kwargs": { "p_gains": 1, diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py 
b/alr_envs/alr/mujoco/reacher/alr_reacher.py index e21eaed..c2b5f18 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -44,9 +44,9 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): reward_dist -= self.reward_weight * np.linalg.norm(vec) if self.steps_before_reward > 0: # avoid giving this penalty for normal step based case - angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) - # angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum() - reward_ctrl = - np.square(a).sum() + # angular_vel -= 10 * np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) + angular_vel -= 10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum() + reward_ctrl = - 10 * np.square(a).sum() if self.balance: reward_balance -= self.balance_weight * np.abs( @@ -64,6 +64,35 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): def viewer_setup(self): self.viewer.cam.trackbodyid = 0 + # def reset_model(self): + # qpos = self.init_qpos + # if not hasattr(self, "goal"): + # self.goal = np.array([-0.25, 0.25]) + # # self.goal = self.init_qpos.copy()[:2] + 0.05 + # qpos[-2:] = self.goal + # qvel = self.init_qvel + # qvel[-2:] = 0 + # self.set_state(qpos, qvel) + # self._steps = 0 + # + # return self._get_obs() + + def reset_model(self): + qpos = self.init_qpos.copy() + while True: + self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + # self.goal = self.np_random.uniform(low=0, high=self.n_links / 10, size=2) + # self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=[0, self.n_links / 10], size=2) + if np.linalg.norm(self.goal) < self.n_links / 10: + break + qpos[-2:] = self.goal + qvel = self.init_qvel.copy() + qvel[-2:] = 0 + self.set_state(qpos, qvel) + self._steps = 0 + + return self._get_obs() + # def reset_model(self): # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos # while True: @@ -78,30 +107,15 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): # # return self._get_obs() - def reset_model(self): - qpos = self.init_qpos - if not hasattr(self, "goal"): - while True: - self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) - if np.linalg.norm(self.goal) < self.n_links / 10: - break - qpos[-2:] = self.goal - qvel = self.init_qvel - qvel[-2:] = 0 - self.set_state(qpos, qvel) - self._steps = 0 - - return self._get_obs() - def _get_obs(self): theta = self.sim.data.qpos.flat[:self.n_links] + target = self.get_body_com("target") return np.concatenate([ np.cos(theta), np.sin(theta), - self.sim.data.qpos.flat[self.n_links:], # this is goal position - self.sim.data.qvel.flat[:self.n_links], # this is angular velocity - self.get_body_com("fingertip") - self.get_body_com("target"), - # self.get_body_com("target"), # only return target to make problem harder + target[:2], # x-y of goal position + self.sim.data.qvel.flat[:self.n_links], # angular velocity + self.get_body_com("fingertip") - target, # goal distance [self._steps], ]) @@ -122,4 +136,4 @@ if __name__ == '__main__': if d: env.reset() - env.close() \ No newline at end of file + env.close() diff --git a/alr_envs/alr/mujoco/reacher/assets/reacher_5links.xml b/alr_envs/alr/mujoco/reacher/assets/reacher_5links.xml index 07be257..25a3208 100644 --- a/alr_envs/alr/mujoco/reacher/assets/reacher_5links.xml +++ b/alr_envs/alr/mujoco/reacher/assets/reacher_5links.xml @@ -1,54 +1,57 @@ - - - - - - \ No newline at end of file diff --git 
a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py index 027a2e2..abcdc50 100644 --- a/alr_envs/alr/mujoco/reacher/mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -9,15 +9,27 @@ class MPWrapper(MPEnvWrapper): @property def active_obs(self): return np.concatenate([ - [True] * self.n_links, # cos - [True] * self.n_links, # sin + [False] * self.n_links, # cos + [False] * self.n_links, # sin [True] * 2, # goal position - [True] * self.n_links, # angular velocity - [True] * 3, # goal distance + [False] * self.n_links, # angular velocity + [False] * 3, # goal distance # self.get_body_com("target"), # only return target to make problem harder - [False], # step + [False], # step ]) + # @property + # def active_obs(self): + # return np.concatenate([ + # [True] * self.n_links, # cos, True + # [True] * self.n_links, # sin, True + # [True] * 2, # goal position + # [True] * self.n_links, # angular velocity, True + # [True] * 3, # goal distance + # # self.get_body_com("target"), # only return target to make problem harder + # [False], # step + # ]) + @property def current_vel(self) -> Union[float, int, np.ndarray]: return self.sim.data.qvel.flat[:self.n_links] diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index bdcaa41..eff432a 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -12,34 +12,38 @@ def visualize(env): plt.plot(t, pos_features) plt.show() + # This might work for some environments, however, please verify either way the correct trajectory information # for your environment are extracted below SEED = 1 # env_id = "ball_in_cup-catch" env_id = "ALRReacherSparse-v0" +env_id = "button-press-v2" wrappers = [mujoco.reacher.MPWrapper] +wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper] mp_kwargs = { - "num_dof": 5, - "num_basis": 8, - "duration": 4, - "policy_type": "motor", - "weights_scale": 1, + "num_dof": 4, + "num_basis": 5, + "duration": 6.25, + "policy_type": "metaworld", + "weights_scale": 10, "zero_start": True, - "policy_kwargs": { - "p_gains": 1, - "d_gains": 0.1 - } + # "policy_kwargs": { + # "p_gains": 1, + # "d_gains": 0.1 + # } } # kwargs = dict(time_limit=4, episode_length=200) kwargs = {} env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs) +env.action_space.seed(SEED) # Plot difference between real trajectory and target MP trajectory env.reset() -w = env.action_space.sample() * 10 +w = env.action_space.sample() # N(0,1) visualize(env) pos, vel = env.mp_rollout(w) @@ -48,14 +52,24 @@ actual_pos = np.zeros((len(pos), *base_shape)) actual_vel = np.zeros((len(pos), *base_shape)) act = np.zeros((len(pos), *base_shape)) +plt.ion() +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +img = ax.imshow(env.env.render("rgb_array")) +fig.show() + for t, pos_vel in enumerate(zip(pos, vel)): actions = env.policy.get_action(pos_vel[0], pos_vel[1]) actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high) _, _, _, _ = env.env.step(actions) + if t % 15 == 0: + img.set_data(env.env.render("rgb_array")) + fig.canvas.draw() + fig.canvas.flush_events() act[t, :] = actions # TODO verify for your environment actual_pos[t, :] = env.current_pos - actual_vel[t, :] = env.current_vel + actual_vel[t, :] = 0 # env.current_vel plt.figure(figsize=(15, 5)) @@ -79,7 +93,7 @@ plt.plot(vel, c='C1', label="MP") plt.xlabel("Episode steps") plt.subplot(133) -plt.title("Actions") 
+plt.title(f"Actions {np.std(act, axis=0)}") plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))]) plt.xlabel("Episode steps") # plt.legend() From 1881c14a48b74b142f939a7a9edd1c47832614ac Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 5 May 2022 16:48:59 +0200 Subject: [PATCH 3/4] reacher adjustments --- alr_envs/alr/__init__.py | 5 +++-- alr_envs/alr/mujoco/reacher/alr_reacher.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 8a7140d..2e23d44 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -97,6 +97,7 @@ register( "hole_depth": 1, "hole_x": None, "collision_penalty": 100, + "rew_fct": "unbounded" } ) @@ -354,7 +355,7 @@ for _v in _versions: "wrappers": [classic_control.hole_reacher.MPWrapper], "mp_kwargs": { "num_dof": 5, - "num_basis": 5, + "num_basis": 3, "duration": 2, "policy_type": "velocity", "weights_scale": 5, @@ -402,7 +403,7 @@ for _v in _versions: "wrappers": [mujoco.reacher.MPWrapper], "mp_kwargs": { "num_dof": 5 if "long" not in _v.lower() else 7, - "num_basis": 1, + "num_basis": 2, "duration": 4, "policy_type": "motor", "weights_scale": 5, diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index c2b5f18..b436fdd 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -39,14 +39,18 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): reward_dist = 0.0 angular_vel = 0.0 reward_balance = 0.0 + is_delayed = self.steps_before_reward > 0 + reward_ctrl = - np.square(a).sum() if self._steps >= self.steps_before_reward: vec = self.get_body_com("fingertip") - self.get_body_com("target") reward_dist -= self.reward_weight * np.linalg.norm(vec) - if self.steps_before_reward > 0: + if is_delayed: # avoid giving this penalty for normal step based case # angular_vel -= 10 * np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) angular_vel -= 10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum() - reward_ctrl = - 10 * np.square(a).sum() + if is_delayed: + # Higher control penalty for sparse reward per timestep + reward_ctrl *= 10 if self.balance: reward_balance -= self.balance_weight * np.abs( From bd4632af8431faa19e698f2f572f6ea87ee0c54e Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 5 May 2022 16:54:39 +0200 Subject: [PATCH 4/4] hole_reacher update --- .../hole_reacher/hole_reacher.py | 3 + .../hole_reacher/hr_unbounded_reward.py | 60 +++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 alr_envs/alr/classic_control/hole_reacher/hr_unbounded_reward.py diff --git a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py index 208f005..883be8c 100644 --- a/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py +++ b/alr_envs/alr/classic_control/hole_reacher/hole_reacher.py @@ -45,6 +45,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): elif rew_fct == "vel_acc": from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty) + elif rew_fct == "unbounded": + from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward + self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision) else: raise ValueError("Unknown reward function {}".format(rew_fct)) diff 
--git a/alr_envs/alr/classic_control/hole_reacher/hr_unbounded_reward.py b/alr_envs/alr/classic_control/hole_reacher/hr_unbounded_reward.py new file mode 100644 index 0000000..7ed13a1 --- /dev/null +++ b/alr_envs/alr/classic_control/hole_reacher/hr_unbounded_reward.py @@ -0,0 +1,60 @@ +import numpy as np + + +class HolereacherReward: + def __init__(self, allow_self_collision, allow_wall_collision): + + # collision + self.allow_self_collision = allow_self_collision + self.allow_wall_collision = allow_wall_collision + self._is_collided = False + + self.reward_factors = np.array((1, -5e-6)) + + def reset(self): + self._is_collided = False + + def get_reward(self, env): + dist_reward = 0 + success = False + + self_collision = False + wall_collision = False + + if not self.allow_self_collision: + self_collision = env._check_self_collision() + + if not self.allow_wall_collision: + wall_collision = env.check_wall_collision() + + self._is_collided = self_collision or wall_collision + + if env._steps == 180 or self._is_collided: + self.end_eff_pos = np.copy(env.end_effector) + + if env._steps == 199 or self._is_collided: + # return reward only in last time step + # Episode also terminates when colliding, hence return reward + dist = np.linalg.norm(self.end_eff_pos - env._goal) + + if self._is_collided: + dist_reward = 0.25 * np.exp(- dist) + else: + if env.end_effector[1] > 0: + dist_reward = np.exp(- dist) + else: + dist_reward = 1 - self.end_eff_pos[1] + + success = not self._is_collided + + info = {"is_success": success, + "is_collided": self._is_collided, + "end_effector": np.copy(env.end_effector), + "joints": np.copy(env.current_pos)} + + acc_cost = np.sum(env._acc ** 2) + + reward_features = np.array((dist_reward, acc_cost)) + reward = np.dot(reward_features, self.reward_factors) + + return reward, info \ No newline at end of file
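The registrations added in PATCH 1/4 (and retuned in PATCH 3/4) expose the reacher as episodic, motion-primitive environments. A minimal sketch of driving the ProMP variant follows; it assumes the usual mp_env_api convention that a single step() call consumes one ProMP weight vector and executes the full 4 s trajectory, and that the colon-prefixed id resolves through gym's module:id lookup -- verify both against the installed alr_envs/mp_env_api versions.

import gym

import alr_envs  # importing the package runs the register() calls added above

# id pattern built in alr_envs/alr/__init__.py: f'{base}ProMP-{version}'
env = gym.make("alr_envs:ALRReacherSparseProMP-v0")
env.seed(1)
obs = env.reset()

# one action = one flat set of ProMP weights (num_dof * num_basis),
# rolled out by the underlying motor policy with p_gains=1, d_gains=0.1
w = env.action_space.sample()
obs, episode_reward, done, info = env.step(w)
print(episode_reward, done)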
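PATCH 2/4 narrows MPWrapper.active_obs so that only the goal's x-y position is kept as context for the episodic setting; joint angles, velocities, the distance vector, and the step counter are masked out. A small standalone check of what that mask selects from the 21-dimensional 5-link observation (layout taken from the updated _get_obs):

import numpy as np

n_links = 5  # ALRReacher; the ALRLongReacher variants use 7
active_obs = np.concatenate([
    [False] * n_links,  # cos(theta)
    [False] * n_links,  # sin(theta)
    [True] * 2,         # goal x-y -- the only context kept
    [False] * n_links,  # angular velocity
    [False] * 3,        # fingertip - target
    [False],            # step counter
])

obs = np.arange(active_obs.size, dtype=float)  # stand-in observation vector
print(np.flatnonzero(active_obs))  # [10 11]
print(obs[active_obs])             # the two goal coordinates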
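The hr_unbounded_reward.HolereacherReward added in PATCH 4/4 is almost entirely terminal: each step only pays the -5e-6 * sum(acc**2) penalty, while the distance term is granted once, at step 199 or on collision, using the end-effector position snapshotted at step 180 (or at the collision step). The sketch below makes that visible with a stand-in object exposing only the attributes get_reward reads; _FakeEnv is purely illustrative, not the real HoleReacherEnv, and the import assumes these patches are applied.

import numpy as np

from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward


class _FakeEnv:
    # illustrative stand-in; only the attributes touched by get_reward()
    def __init__(self):
        self._steps = 0
        self.end_effector = np.array([0.3, 0.5])  # y > 0 branch of the terminal reward
        self._goal = np.array([2.0, -1.0])
        self._acc = np.zeros(5)                   # zero acceleration -> zero per-step cost
        self.current_pos = np.zeros(5)

    def _check_self_collision(self):
        return False

    def check_wall_collision(self):
        return False


reward_fct = HolereacherReward(allow_self_collision=False, allow_wall_collision=False)
reward_fct.reset()
env = _FakeEnv()

returns = []
for step in range(200):
    env._steps = step
    r, info = reward_fct.get_reward(env)
    returns.append(r)

print(returns[0], returns[198])  # 0.0 -- nothing is paid before the final step
print(returns[199])              # exp(-||end_eff_pos(step 180) - goal||)
print(info["is_success"])        # True, since nothing collided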