From 6a7c6991bb518a6f2c89a4ef5c39b152418de515 Mon Sep 17 00:00:00 2001
From: ottofabian
Date: Fri, 26 Mar 2021 15:32:50 +0100
Subject: [PATCH] added balancing task

---
 alr_envs/__init__.py         |  9 ++++++
 alr_envs/mujoco/__init__.py  |  3 +-
 alr_envs/mujoco/balancing.py | 53 ++++++++++++++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 1 deletion(-)
 create mode 100644 alr_envs/mujoco/balancing.py

diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py
index 62e5db0..2e50f80 100644
--- a/alr_envs/__init__.py
+++ b/alr_envs/__init__.py
@@ -93,6 +93,15 @@ register(
     }
 )
 
+register(
+    id='Balancing-v0',
+    entry_point='alr_envs.mujoco:BalancingEnv',
+    max_episode_steps=200,
+    kwargs={
+        "n_links": 5,
+    }
+)
+
 register(
     id='SimpleReacher-v0',
     entry_point='alr_envs.classic_control:SimpleReacherEnv',
diff --git a/alr_envs/mujoco/__init__.py b/alr_envs/mujoco/__init__.py
index 77588f7..b149793 100644
--- a/alr_envs/mujoco/__init__.py
+++ b/alr_envs/mujoco/__init__.py
@@ -1 +1,2 @@
-from alr_envs.mujoco.alr_reacher import ALRReacherEnv
\ No newline at end of file
+from alr_envs.mujoco.alr_reacher import ALRReacherEnv
+from alr_envs.mujoco.balancing import BalancingEnv
diff --git a/alr_envs/mujoco/balancing.py b/alr_envs/mujoco/balancing.py
new file mode 100644
index 0000000..5976bc2
--- /dev/null
+++ b/alr_envs/mujoco/balancing.py
@@ -0,0 +1,53 @@
+import os
+
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+
+from alr_envs.utils.utils import angle_normalize
+
+
+class BalancingEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+    def __init__(self, n_links=5):
+        utils.EzPickle.__init__(**locals())
+
+        self.n_links = n_links
+
+        if n_links == 5:
+            file_name = 'reacher_5links.xml'
+        elif n_links == 7:
+            file_name = 'reacher_7links.xml'
+        else:
+            raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.")
+
+        mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)  # frame_skip=2
+
+    def step(self, a):
+        angle = angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad")
+        reward = -np.abs(angle)  # penalize deviation of the last link's global orientation from 0 rad
+
+        self.do_simulation(a, self.frame_skip)
+        ob = self._get_obs()
+        done = False  # episodes are ended by the max_episode_steps wrapper
+        return ob, reward, done, dict(angle=angle, end_effector=self.get_body_com("fingertip").copy())
+
+    def viewer_setup(self):
+        self.viewer.cam.trackbodyid = 1
+
+    def reset_model(self):
+        # This also generates a goal; however, we do not need or use it here.
+        qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
+        qpos[-2:] = 0
+        qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
+        qvel[-2:] = 0
+        self.set_state(qpos, qvel)
+
+        return self._get_obs()
+
+    def _get_obs(self):
+        theta = self.sim.data.qpos.flat[:self.n_links]
+        return np.concatenate([
+            np.cos(theta),
+            np.sin(theta),
+            self.sim.data.qvel.flat[:self.n_links],  # joint angular velocities
+        ])
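
Note: angle_normalize is imported from alr_envs.utils.utils, which is not part of this diff. A minimal sketch of what such a helper could look like, assuming type="rad" selects radians and the angle is wrapped into [-pi, pi) (hypothetical, not the repository's actual implementation):

import numpy as np

def angle_normalize(x, type="rad"):
    # Hypothetical sketch: wrap an angle into [-pi, pi) radians or [-180, 180) degrees.
    if type == "rad":
        return ((x + np.pi) % (2 * np.pi)) - np.pi
    elif type == "deg":
        return ((x + 180.0) % 360.0) - 180.0
    raise ValueError(f"Invalid type {type}, only 'rad' or 'deg' allowed.")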
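
For reference, once this patch is applied the registered task can be driven like any other Gym environment; a minimal rollout sketch, assuming the 2021-era Gym step API (4-tuple) that this code targets:

import gym
import alr_envs  # importing the package runs the register() calls added above

env = gym.make("Balancing-v0")  # 5-link variant, up to 200 steps per episode
obs = env.reset()
done = False
while not done:
    # Random torques; a real controller would drive info["angle"] towards 0.
    obs, reward, done, info = env.step(env.action_space.sample())
env.close()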