Added ALRReacherProMP

This commit is contained in:
Fabian 2022-01-25 15:23:57 +01:00
parent 3f3bb98e84
commit 1f5a7b67f5
5 changed files with 142 additions and 26 deletions

View File

@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
from .mujoco.reacher.alr_reacher import ALRReacherEnv
from .mujoco.reacher.balancing import BalancingEnv
from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
from .mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
@ -364,6 +364,58 @@ for _v in _versions:
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## ALRReacher
# Register a DMP and a ProMP variant for every ALRReacher flavour. The new ids
# are built by inserting "DMP"/"ProMP" in front of the version suffix, e.g.
# "ALRReacher-v0" -> "ALRReacherDMP-v0" / "ALRReacherProMP-v0".
_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"]
for _v in _versions:
    _name = _v.split("-")
    # Ids containing "long" (case-insensitive) are configured with 7 DoF, all others with 5.
    _num_dof = 7 if "long" in _v.lower() else 5

    # --- DMP variant ---
    _env_id = f'{_name[0]}DMP-{_name[1]}'
    register(
        id=_env_id,
        entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
        # max_episode_steps=1,
        kwargs=dict(
            name=f"alr_envs:{_v}",
            wrappers=[mujoco.reacher.MPWrapper],
            mp_kwargs=dict(
                num_dof=_num_dof,
                num_basis=5,
                duration=4,
                alpha_phase=2,
                learn_goal=True,
                policy_type="motor",
                weights_scale=1,
                policy_kwargs=dict(
                    p_gains=1,
                    d_gains=0.1,
                ),
            ),
        ),
    )
    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)

    # --- ProMP variant ---
    _env_id = f'{_name[0]}ProMP-{_name[1]}'
    register(
        id=_env_id,
        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
        kwargs=dict(
            name=f"alr_envs:{_v}",
            wrappers=[mujoco.reacher.MPWrapper],
            mp_kwargs=dict(
                num_dof=_num_dof,
                num_basis=5,
                duration=4,
                policy_type="motor",
                weights_scale=1,
                zero_start=True,
                policy_kwargs=dict(
                    p_gains=1,
                    d_gains=0.1,
                ),
            ),
        ),
    )
    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## Beerpong
_versions = ["v0", "v1", "v2", "v3"]
for _v in _versions:

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -42,7 +42,10 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
if self._steps >= self.steps_before_reward:
vec = self.get_body_com("fingertip") - self.get_body_com("target")
reward_dist -= self.reward_weight * np.linalg.norm(vec)
if self.steps_before_reward > 0:
# avoid giving this penalty for normal step based case
angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
# angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
reward_ctrl = - np.square(a).sum()
if self.balance:
@ -61,14 +64,29 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
def viewer_setup(self):
self.viewer.cam.trackbodyid = 0
# def reset_model(self):
# qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
# while True:
# self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
# if np.linalg.norm(self.goal) < self.n_links / 10:
# break
# qpos[-2:] = self.goal
# qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
# qvel[-2:] = 0
# self.set_state(qpos, qvel)
# self._steps = 0
#
# return self._get_obs()
def reset_model(self):
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
qpos = self.init_qpos
if not hasattr(self, "goal"):
while True:
self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
if np.linalg.norm(self.goal) < self.n_links / 10:
break
qpos[-2:] = self.goal
qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
qvel = self.init_qvel
qvel[-2:] = 0
self.set_state(qpos, qvel)
self._steps = 0

View File

@ -0,0 +1,31 @@
from typing import Union
import numpy as np
from mp_env_api import MPEnvWrapper
class MPWrapper(MPEnvWrapper):
    """Motion-primitive wrapper exposing the reacher's joint state and observation mask."""

    @property
    def active_obs(self):
        """Boolean mask over the observation vector; only the step entry is disabled."""
        n_links = self.n_links
        segments = (
            np.ones(n_links, dtype=bool),  # cos
            np.ones(n_links, dtype=bool),  # sin
            np.ones(2, dtype=bool),  # goal position
            np.ones(n_links, dtype=bool),  # angular velocity
            np.ones(3, dtype=bool),  # goal distance
            # self.get_body_com("target"), # only return target to make problem harder
            np.zeros(1, dtype=bool),  # step
        )
        return np.concatenate(segments)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray]:
        """First ``n_links`` entries of the simulator's flattened ``qvel``."""
        joint_count = self.n_links
        velocities = self.sim.data.qvel
        return velocities.flat[:joint_count]

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        """First ``n_links`` entries of the simulator's flattened ``qpos``."""
        joint_count = self.n_links
        positions = self.sim.data.qpos
        return positions.flat[:joint_count]

    @property
    def dt(self) -> Union[float, int]:
        """Time step reported by the wrapped environment."""
        wrapped_env = self.env
        return wrapped_env.dt

View File

@ -2,36 +2,46 @@ import numpy as np
from matplotlib import pyplot as plt
from alr_envs import dmc, meta
from alr_envs.alr import mujoco
from alr_envs.utils.make_env_helpers import make_promp_env
def visualize(env):
    """Plot the motion primitive's basis features over the environment's time points.

    Evaluates ``env.mp.basis_generator.basis`` at ``env.t``, plots the result
    against ``env.t`` and shows the figure.
    """
    time_points = env.t
    basis_features = env.mp.basis_generator.basis(time_points)
    plt.plot(time_points, basis_features)
    plt.show()
# This might work for some environments; however, please verify in any case that the correct trajectory
# information for your environment is extracted below.
SEED = 10
env_id = "ball_in_cup-catch"
wrappers = [dmc.ball_in_cup.MPWrapper]
SEED = 1
# env_id = "ball_in_cup-catch"
env_id = "ALRReacherSparse-v0"
wrappers = [mujoco.reacher.MPWrapper]
mp_kwargs = {
"num_dof": 2,
"num_basis": 10,
"duration": 2,
"width": 0.025,
"num_dof": 5,
"num_basis": 8,
"duration": 4,
"policy_type": "motor",
"weights_scale": 1,
"zero_start": True,
"policy_kwargs": {
"p_gains": 1,
"d_gains": 1
"d_gains": 0.1
}
}
kwargs = dict(time_limit=2, episode_length=100)
# kwargs = dict(time_limit=4, episode_length=200)
kwargs = {}
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
**kwargs)
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
# Plot difference between real trajectory and target MP trajectory
env.reset()
pos, vel = env.mp_rollout(env.action_space.sample())
w = env.action_space.sample() * 10
visualize(env)
pos, vel = env.mp_rollout(w)
base_shape = env.full_action_space.shape
actual_pos = np.zeros((len(pos), *base_shape))
@ -51,18 +61,22 @@ plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title("Position")
plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
p1 = plt.plot(actual_pos, c='C0', label="true")
# plt.plot(actual_pos_ball, label="true pos ball")
plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
p2 = plt.plot(pos, c='C1', label="MP") # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))])
plt.xlabel("Episode steps")
plt.legend()
# plt.legend()
handles, labels = plt.gca().get_legend_handles_labels()
from collections import OrderedDict
by_label = OrderedDict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())
plt.subplot(132)
plt.title("Velocity")
plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.plot(actual_vel, c='C0', label="true")
plt.plot(vel, c='C1', label="MP")
plt.xlabel("Episode steps")
plt.legend()
plt.subplot(133)
plt.title("Actions")