diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 90ec78c..53f292c 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv from .mujoco.reacher.alr_reacher import ALRReacherEnv from .mujoco.reacher.balancing import BalancingEnv -from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS +from .mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} @@ -363,6 +363,58 @@ for _v in _versions: } ) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) + +## ALRReacher +_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"] +for _v in _versions: + _name = _v.split("-") + _env_id = f'{_name[0]}DMP-{_name[1]}' + register( + id=_env_id, + entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', + # max_episode_steps=1, + kwargs={ + "name": f"alr_envs:{_v}", + "wrappers": [mujoco.reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5 if "long" not in _v.lower() else 7, + "num_basis": 5, + "duration": 4, + "alpha_phase": 2, + "learn_goal": True, + "policy_type": "motor", + "weights_scale": 1, + "policy_kwargs": { + "p_gains": 1, + "d_gains": 0.1 + } + } + } + ) + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) + + _env_id = f'{_name[0]}ProMP-{_name[1]}' + register( + id=_env_id, + entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', + kwargs={ + "name": f"alr_envs:{_v}", + "wrappers": [mujoco.reacher.MPWrapper], + "mp_kwargs": { + "num_dof": 5 if "long" not in _v.lower() else 7, + "num_basis": 5, + "duration": 4, + "policy_type": "motor", + "weights_scale": 1, + "zero_start": True, + "policy_kwargs": { + "p_gains": 1, + "d_gains": 0.1 + } + } + } + ) + ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ## Beerpong _versions = ["v0", "v1", "v2", "v3"] diff --git a/alr_envs/alr/mujoco/reacher/__init__.py b/alr_envs/alr/mujoco/reacher/__init__.py index e69de29..989b5a9 100644 --- a/alr_envs/alr/mujoco/reacher/__init__.py +++ b/alr_envs/alr/mujoco/reacher/__init__.py @@ -0,0 +1 @@ +from .mp_wrapper import MPWrapper \ No newline at end of file diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index 2d122d2..e21eaed 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -42,7 +42,10 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): if self._steps >= self.steps_before_reward: vec = self.get_body_com("fingertip") - self.get_body_com("target") reward_dist -= self.reward_weight * np.linalg.norm(vec) - angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) + if self.steps_before_reward > 0: + # avoid giving this penalty for normal step based case + angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links]) + # angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum() reward_ctrl = - np.square(a).sum() if self.balance: @@ -61,14 +64,29 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle): def viewer_setup(self): self.viewer.cam.trackbodyid = 0 + # def reset_model(self): + # qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos + # while True: + # self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + # if np.linalg.norm(self.goal) < self.n_links / 10: + # break + # qpos[-2:] = self.goal + # qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + # qvel[-2:] = 0 + # self.set_state(qpos, qvel) + # self._steps = 0 + # + # return self._get_obs() + def reset_model(self): - qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos - while True: - self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) - if np.linalg.norm(self.goal) < self.n_links / 10: - break + qpos = self.init_qpos + if not hasattr(self, "goal"): + while True: + self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2) + if np.linalg.norm(self.goal) < self.n_links / 10: + break qpos[-2:] = self.goal - qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) + qvel = self.init_qvel qvel[-2:] = 0 self.set_state(qpos, qvel) self._steps = 0 diff --git a/alr_envs/alr/mujoco/reacher/mp_wrapper.py b/alr_envs/alr/mujoco/reacher/mp_wrapper.py new file mode 100644 index 0000000..027a2e2 --- /dev/null +++ b/alr_envs/alr/mujoco/reacher/mp_wrapper.py @@ -0,0 +1,31 @@ +from typing import Union + +import numpy as np +from mp_env_api import MPEnvWrapper + + +class MPWrapper(MPEnvWrapper): + + @property + def active_obs(self): + return np.concatenate([ + [True] * self.n_links, # cos + [True] * self.n_links, # sin + [True] * 2, # goal position + [True] * self.n_links, # angular velocity + [True] * 3, # goal distance + # self.get_body_com("target"), # only return target to make problem harder + [False], # step + ]) + + @property + def current_vel(self) -> Union[float, int, np.ndarray]: + return self.sim.data.qvel.flat[:self.n_links] + + @property + def current_pos(self) -> Union[float, int, np.ndarray]: + return self.sim.data.qpos.flat[:self.n_links] + + @property + def dt(self) -> Union[float, int]: + return self.env.dt diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index 90aac11..bdcaa41 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -2,36 +2,46 @@ import numpy as np from matplotlib import pyplot as plt from alr_envs import dmc, meta +from alr_envs.alr import mujoco from alr_envs.utils.make_env_helpers import make_promp_env + +def visualize(env): + t = env.t + pos_features = env.mp.basis_generator.basis(t) + plt.plot(t, pos_features) + plt.show() + # This might work for some environments, however, please verify either way the correct trajectory information # for your environment are extracted below -SEED = 10 -env_id = "ball_in_cup-catch" -wrappers = [dmc.ball_in_cup.MPWrapper] +SEED = 1 +# env_id = "ball_in_cup-catch" +env_id = "ALRReacherSparse-v0" +wrappers = [mujoco.reacher.MPWrapper] mp_kwargs = { - "num_dof": 2, - "num_basis": 10, - "duration": 2, - "width": 0.025, + "num_dof": 5, + "num_basis": 8, + "duration": 4, "policy_type": "motor", "weights_scale": 1, "zero_start": True, "policy_kwargs": { "p_gains": 1, - "d_gains": 1 + "d_gains": 0.1 } } -kwargs = dict(time_limit=2, episode_length=100) +# kwargs = dict(time_limit=4, episode_length=200) +kwargs = {} -env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, - **kwargs) +env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs) # Plot difference between real trajectory and target MP trajectory env.reset() -pos, vel = env.mp_rollout(env.action_space.sample()) +w = env.action_space.sample() * 10 +visualize(env) +pos, vel = env.mp_rollout(w) base_shape = env.full_action_space.shape actual_pos = np.zeros((len(pos), *base_shape)) @@ -51,18 +61,22 @@ plt.figure(figsize=(15, 5)) plt.subplot(131) plt.title("Position") -plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))]) +p1 = plt.plot(actual_pos, c='C0', label="true") # plt.plot(actual_pos_ball, label="true pos ball") -plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +p2 = plt.plot(pos, c='C1', label="MP") # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))]) plt.xlabel("Episode steps") -plt.legend() +# plt.legend() +handles, labels = plt.gca().get_legend_handles_labels() +from collections import OrderedDict + +by_label = OrderedDict(zip(labels, handles)) +plt.legend(by_label.values(), by_label.keys()) plt.subplot(132) plt.title("Velocity") -plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))]) -plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))]) +plt.plot(actual_vel, c='C0', label="true") +plt.plot(vel, c='C1', label="MP") plt.xlabel("Episode steps") -plt.legend() plt.subplot(133) plt.title("Actions")