added simple reacher task
This commit is contained in:
parent
96d1f93bda
commit
31156cec4d
111
.gitignore
vendored
Normal file
111
.gitignore
vendored
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
# pycharm
|
||||||
|
.DS_Store
|
||||||
|
/.idea
|
||||||
|
|
||||||
|
#configs
|
||||||
|
/configs/db.cfg
|
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="jdk" jdkName="Python 3.7 (trustpo)" jdkType="Python SDK" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
@ -10,4 +10,4 @@
|
|||||||
- Install: go to "../reacher_5_links"
|
- Install: go to "../reacher_5_links"
|
||||||
``` pip install -e reacher_5_links ```
|
``` pip install -e reacher_5_links ```
|
||||||
- Use (see example.py):
|
- Use (see example.py):
|
||||||
``` env = gym.make('reacher:ReacherALREnv-v0')```
|
``` env = gym.make('reacher:ALRReacherEnv-v0')```
|
0
__init__.py
Normal file
0
__init__.py
Normal file
16
alr_envs/__init__.py
Normal file
16
alr_envs/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from gym.envs.registration import register
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='ALRReacher-v0',
|
||||||
|
entry_point='alr_envs.mujoco:ALRReacherEnv',
|
||||||
|
max_episode_steps=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='SimpleReacher-v0',
|
||||||
|
entry_point='alr_envs.classic_control:SimpleReacherEnv',
|
||||||
|
max_episode_steps=200,
|
||||||
|
kwargs={
|
||||||
|
"n_links": 5,
|
||||||
|
}
|
||||||
|
)
|
1
alr_envs/classic_control/__init__.py
Normal file
1
alr_envs/classic_control/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from alr_envs.classic_control.simple_reacher import SimpleReacherEnv
|
166
alr_envs/classic_control/simple_reacher.py
Normal file
166
alr_envs/classic_control/simple_reacher.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
from gym import spaces, utils
|
||||||
|
from gym.utils import seeding
|
||||||
|
|
||||||
|
import matplotlib as mpl
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
mpl.use('Qt5Agg') # or can use 'TkAgg', whatever you have/prefer
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleReacherEnv(gym.Env, utils.EzPickle):
|
||||||
|
"""
|
||||||
|
Simple Reaching Task without any physics simulation.
|
||||||
|
Returns no reward until 150 time steps. This allows the agent to explore the space, but requires precise actions
|
||||||
|
towards the end of the trajectory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_links):
|
||||||
|
super().__init__()
|
||||||
|
self.link_lengths = np.ones(n_links)
|
||||||
|
self.n_links = n_links
|
||||||
|
self.dt = 0.1
|
||||||
|
|
||||||
|
self._goal_pos = None
|
||||||
|
|
||||||
|
self.joints = None
|
||||||
|
self._joint_angle = None
|
||||||
|
self._angle_velocity = None
|
||||||
|
|
||||||
|
self.max_torque = 1 # 10
|
||||||
|
|
||||||
|
action_bound = np.ones((self.n_links,))
|
||||||
|
state_bound = np.hstack([
|
||||||
|
[np.pi] * self.n_links,
|
||||||
|
[np.inf] * self.n_links,
|
||||||
|
[np.inf],
|
||||||
|
[np.inf] # TODO: Maybe
|
||||||
|
])
|
||||||
|
self.action_space = spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)
|
||||||
|
self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
|
||||||
|
|
||||||
|
self.fig = None
|
||||||
|
self.metadata = {'render.modes': ["human"]}
|
||||||
|
|
||||||
|
self._steps = 0
|
||||||
|
self.seed()
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
|
||||||
|
action = self._scale_action(action)
|
||||||
|
|
||||||
|
self._angle_velocity = self._angle_velocity + self.dt * action
|
||||||
|
self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity)
|
||||||
|
self._update_joints()
|
||||||
|
self._steps += 1
|
||||||
|
|
||||||
|
reward = self._get_reward(action)
|
||||||
|
|
||||||
|
# done = np.abs(self.end_effector - self._goal_pos) < 0.1
|
||||||
|
done = False
|
||||||
|
|
||||||
|
return self._get_obs().copy(), reward, done, {}
|
||||||
|
|
||||||
|
def _scale_action(self, action):
|
||||||
|
"""
|
||||||
|
scale actions back in order to provide normalized actions \in [0,1]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: action to scale
|
||||||
|
|
||||||
|
Returns: action according to self.max_torque
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
ub = self.max_torque
|
||||||
|
lb = -self.max_torque
|
||||||
|
|
||||||
|
action = lb + (action + 1.) * 0.5 * (ub - lb)
|
||||||
|
return np.clip(action, lb, ub)
|
||||||
|
|
||||||
|
def _get_obs(self):
|
||||||
|
return [self._joint_angle, self._angle_velocity, self.end_effector - self._goal_pos, self._steps]
|
||||||
|
|
||||||
|
def _update_joints(self):
|
||||||
|
"""
|
||||||
|
update joints to get new end effector position. The other links are only required for rendering.
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
angles = np.cumsum(self._joint_angle)
|
||||||
|
x = self.link_lengths * np.vstack([np.cos(angles), np.sin(angles)])
|
||||||
|
self.joints[1:] = self.joints[0] + np.cumsum(x.T, axis=0)
|
||||||
|
|
||||||
|
def _get_reward(self, action):
|
||||||
|
diff = self.end_effector - self._goal_pos
|
||||||
|
distance = 0
|
||||||
|
|
||||||
|
# TODO: Is this the best option
|
||||||
|
if self._steps > 150:
|
||||||
|
distance = np.exp(-0.1 * diff ** 2).mean()
|
||||||
|
# distance -= (diff ** 2).mean()
|
||||||
|
|
||||||
|
# distance -= action ** 2
|
||||||
|
return distance
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
|
||||||
|
# TODO: maybe do initialisation more random?
|
||||||
|
# Sample only orientation of first link, i.e. the arm is always straight.
|
||||||
|
self._joint_angle = np.hstack([[self.np_random.uniform(-np.pi, np.pi)], np.zeros(self.n_links - 1)])
|
||||||
|
self._angle_velocity = np.zeros(self.n_links)
|
||||||
|
self.joints = np.zeros((self.n_links + 1, 2))
|
||||||
|
self._update_joints()
|
||||||
|
|
||||||
|
self._goal_pos = self._get_random_goal()
|
||||||
|
return self._get_obs().copy()
|
||||||
|
|
||||||
|
def _get_random_goal(self):
|
||||||
|
center = self.joints[0]
|
||||||
|
|
||||||
|
# Sample uniformly in circle with radius R around center of reacher.
|
||||||
|
R = np.sum(self.link_lengths)
|
||||||
|
r = R * np.sqrt(self.np_random.uniform())
|
||||||
|
theta = self.np_random.uniform() * 2 * np.pi
|
||||||
|
return center + r * np.stack([np.cos(theta), np.sin(theta)])
|
||||||
|
|
||||||
|
def seed(self, seed=None):
|
||||||
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
|
return [seed]
|
||||||
|
|
||||||
|
def render(self, mode='human'): # pragma: no cover
|
||||||
|
if self.fig is None:
|
||||||
|
self.fig = plt.figure()
|
||||||
|
plt.ion()
|
||||||
|
plt.show()
|
||||||
|
else:
|
||||||
|
plt.figure(self.fig.number)
|
||||||
|
|
||||||
|
plt.cla()
|
||||||
|
|
||||||
|
# Arm
|
||||||
|
plt.plot(self.joints[:, 0], self.joints[:, 1], 'ro-', markerfacecolor='k')
|
||||||
|
|
||||||
|
# goal
|
||||||
|
goal_pos = self._goal_pos.T
|
||||||
|
plt.plot(goal_pos[0], goal_pos[1], 'gx')
|
||||||
|
# distance between end effector and goal
|
||||||
|
plt.plot([self.end_effector[0], goal_pos[0]], [self.end_effector[1], goal_pos[1]], 'g--')
|
||||||
|
|
||||||
|
lim = np.sum(self.link_lengths) + 0.5
|
||||||
|
plt.xlim([-lim, lim])
|
||||||
|
plt.ylim([-lim, lim])
|
||||||
|
plt.draw()
|
||||||
|
plt.pause(0.0001)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
del self.fig
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_effector(self):
|
||||||
|
return self.joints[self.n_links].T
|
||||||
|
|
||||||
|
|
||||||
|
def angle_normalize(x):
|
||||||
|
return ((x + np.pi) % (2 * np.pi)) - np.pi
|
1
alr_envs/mujoco/__init__.py
Normal file
1
alr_envs/mujoco/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from alr_envs.mujoco.alr_reacher import ALRReacherEnv
|
@ -1,14 +1,16 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import mujoco_env
|
from gym.envs.mujoco import mujoco_env
|
||||||
|
|
||||||
class ReacherALREnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|
||||||
|
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
utils.EzPickle.__init__(self)
|
utils.EzPickle.__init__(self)
|
||||||
mujoco_env.MujocoEnv.__init__(self, '/home/vien/git/reacher_test/reacher/envs/reacher_5links.xml', 2)
|
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", 'reacher_5links.xml'), 2)
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
vec = self.get_body_com("fingertip")-self.get_body_com("target")
|
vec = self.get_body_com("fingertip") - self.get_body_com("target")
|
||||||
reward_dist = - np.linalg.norm(vec)
|
reward_dist = - np.linalg.norm(vec)
|
||||||
reward_ctrl = - np.square(a).sum()
|
reward_ctrl = - np.square(a).sum()
|
||||||
reward = reward_dist + reward_ctrl
|
reward = reward_dist + reward_ctrl
|
16
example.py
16
example.py
@ -1,13 +1,15 @@
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# env = gym.make('alr_envs:ALRReacher-v0')
|
||||||
env = gym.make('reacher:ReacherALREnv-v0')
|
env = gym.make('alr_envs:SimpleReacher-v0')
|
||||||
#env = gym.make('Hopper-v2')
|
state = env.reset()
|
||||||
env.reset()
|
|
||||||
|
|
||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
action = env.action_space.sample()
|
state, reward, done, info = env.step(env.action_space.sample())
|
||||||
obs = env.step(action)
|
if i % 5 == 0:
|
||||||
print("step",i)
|
|
||||||
env.render()
|
env.render()
|
||||||
|
|
||||||
|
if done:
|
||||||
|
state = env.reset()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
README.md
|
||||||
setup.py
|
setup.py
|
||||||
reacher.egg-info/PKG-INFO
|
reacher.egg-info/PKG-INFO
|
||||||
reacher.egg-info/SOURCES.txt
|
reacher.egg-info/SOURCES.txt
|
||||||
|
@ -1,6 +0,0 @@
|
|||||||
from gym.envs.registration import register
|
|
||||||
|
|
||||||
register(
|
|
||||||
id='ReacherALREnv-v0',
|
|
||||||
entry_point='reacher.envs:ReacherALREnv',
|
|
||||||
)
|
|
Binary file not shown.
@ -1 +0,0 @@
|
|||||||
from reacher.envs.reacher_env import ReacherALREnv
|
|
Binary file not shown.
Binary file not shown.
2
setup.py
2
setup.py
@ -1,6 +1,6 @@
|
|||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
setup(name='reacher',
|
setup(name='alr_envs',
|
||||||
version='0.0.1',
|
version='0.0.1',
|
||||||
install_requires=['gym'] # And any other dependencies foo needs
|
install_requires=['gym'] # And any other dependencies foo needs
|
||||||
)
|
)
|
Loading…
Reference in New Issue
Block a user