Added ALRReacherProMP

parent 3f3bb98e84
commit 1f5a7b67f5
alr_envs/alr/__init__.py
@@ -10,7 +10,7 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
 from .mujoco.reacher.alr_reacher import ALRReacherEnv
 from .mujoco.reacher.balancing import BalancingEnv

-from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
+from .mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS

 ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}

@@ -364,6 +364,58 @@ for _v in _versions:
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

+## ALRReacher
+_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"]
+for _v in _versions:
+    _name = _v.split("-")
+    _env_id = f'{_name[0]}DMP-{_name[1]}'
+    register(
+        id=_env_id,
+        entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
+        # max_episode_steps=1,
+        kwargs={
+            "name": f"alr_envs:{_v}",
+            "wrappers": [mujoco.reacher.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 5 if "long" not in _v.lower() else 7,
+                "num_basis": 5,
+                "duration": 4,
+                "alpha_phase": 2,
+                "learn_goal": True,
+                "policy_type": "motor",
+                "weights_scale": 1,
+                "policy_kwargs": {
+                    "p_gains": 1,
+                    "d_gains": 0.1
+                }
+            }
+        }
+    )
+    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
+
+    _env_id = f'{_name[0]}ProMP-{_name[1]}'
+    register(
+        id=_env_id,
+        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
+        kwargs={
+            "name": f"alr_envs:{_v}",
+            "wrappers": [mujoco.reacher.MPWrapper],
+            "mp_kwargs": {
+                "num_dof": 5 if "long" not in _v.lower() else 7,
+                "num_basis": 5,
+                "duration": 4,
+                "policy_type": "motor",
+                "weights_scale": 1,
+                "zero_start": True,
+                "policy_kwargs": {
+                    "p_gains": 1,
+                    "d_gains": 0.1
+                }
+            }
+        }
+    )
+    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
 ## Beerpong
 _versions = ["v0", "v1", "v2", "v3"]
 for _v in _versions:
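Note: a minimal usage sketch for one of the freshly registered ProMP variants (assumption: register() above is gym's registry, so the id resolves through the make_promp_env_helper entry point; ids follow the {name}ProMP-{version} pattern built in the loop):

    import gym
    import alr_envs  # noqa: F401 -- importing the package runs the register() calls above

    env = gym.make("ALRReacherSparseProMP-v0")
    env.reset()
    # A single step executes the whole 4s ProMP trajectory;
    # the action is the flat basis-weight vector (5 DoF x 5 basis functions here).
    obs, reward, done, info = env.step(env.action_space.sample())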
alr_envs/alr/mujoco/reacher/__init__.py (new file)
@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper
alr_envs/alr/mujoco/reacher/alr_reacher.py
@@ -42,7 +42,10 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
         if self._steps >= self.steps_before_reward:
             vec = self.get_body_com("fingertip") - self.get_body_com("target")
             reward_dist -= self.reward_weight * np.linalg.norm(vec)
-            angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
+            if self.steps_before_reward > 0:
+                # avoid giving this penalty in the normal step-based case
+                angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
+                # angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
         reward_ctrl = - np.square(a).sum()

         if self.balance:
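Note: the new guard applies the angular-velocity penalty only when steps_before_reward > 0, i.e. the episodic/sparse configurations; the dense step-based reward is unchanged. A condensed sketch of the resulting logic (names mirror the hunk; the final sum is an assumption, not shown in this hunk):

    reward_dist, angular_vel = 0.0, 0.0
    if self._steps >= self.steps_before_reward:
        vec = self.get_body_com("fingertip") - self.get_body_com("target")
        reward_dist -= self.reward_weight * np.linalg.norm(vec)
        if self.steps_before_reward > 0:  # episodic/sparse case only
            angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
    reward_ctrl = -np.square(a).sum()
    reward = reward_dist + reward_ctrl + angular_vel  # assumed composition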
@@ -61,14 +64,29 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
     def viewer_setup(self):
         self.viewer.cam.trackbodyid = 0

+    # def reset_model(self):
+    #     qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
+    #     while True:
+    #         self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
+    #         if np.linalg.norm(self.goal) < self.n_links / 10:
+    #             break
+    #     qpos[-2:] = self.goal
+    #     qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
+    #     qvel[-2:] = 0
+    #     self.set_state(qpos, qvel)
+    #     self._steps = 0
+    #
+    #     return self._get_obs()
+
     def reset_model(self):
-        qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
-        while True:
-            self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
-            if np.linalg.norm(self.goal) < self.n_links / 10:
-                break
+        qpos = self.init_qpos
+        if not hasattr(self, "goal"):
+            while True:
+                self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
+                if np.linalg.norm(self.goal) < self.n_links / 10:
+                    break
         qpos[-2:] = self.goal
-        qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
+        qvel = self.init_qvel
         qvel[-2:] = 0
         self.set_state(qpos, qvel)
         self._steps = 0
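Note: with the hasattr guard the goal is sampled once per environment instance and then reused on every reset, so MP rollouts always target the same point. Also note that qpos = self.init_qpos aliases the array rather than copying it; since the goal no longer changes between resets this is benign here. A hypothetical sanity check (assumes the base id resolves via gym and the wrapper forwards attribute access):

    import gym
    import numpy as np

    env = gym.make("alr_envs:ALRReacherSparse-v0")
    env.reset()
    goal_first = env.goal.copy()
    env.reset()
    assert np.allclose(goal_first, env.goal)  # goal now persists across resets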
alr_envs/alr/mujoco/reacher/mp_wrapper.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+from typing import Union
+
+import numpy as np
+from mp_env_api import MPEnvWrapper
+
+
+class MPWrapper(MPEnvWrapper):
+
+    @property
+    def active_obs(self):
+        return np.concatenate([
+            [True] * self.n_links,  # cos
+            [True] * self.n_links,  # sin
+            [True] * 2,  # goal position
+            [True] * self.n_links,  # angular velocity
+            [True] * 3,  # goal distance
+            # self.get_body_com("target"),  # only return target to make problem harder
+            [False],  # step
+        ])
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray]:
+        return self.sim.data.qvel.flat[:self.n_links]
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray]:
+        return self.sim.data.qpos.flat[:self.n_links]
+
+    @property
+    def dt(self) -> Union[float, int]:
+        return self.env.dt
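Note: active_obs is a boolean mask selecting which entries of the base observation make up the MP context. An illustrative application of such a mask (sizes follow the comments above for the 5-link reacher; the observation here is a random stand-in, not the real layout):

    import numpy as np

    n_links = 5
    mask = np.concatenate([
        [True] * n_links,  # cos(q)
        [True] * n_links,  # sin(q)
        [True] * 2,        # goal position
        [True] * n_links,  # angular velocity
        [True] * 3,        # goal distance
        [False],           # the step counter is dropped from the context
    ])
    obs = np.random.randn(mask.size)  # stand-in for a raw observation (21 entries)
    context = obs[mask]               # 20-dimensional context passed to the MP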
@@ -2,36 +2,46 @@ import numpy as np
 from matplotlib import pyplot as plt

 from alr_envs import dmc, meta
+from alr_envs.alr import mujoco
 from alr_envs.utils.make_env_helpers import make_promp_env


+def visualize(env):
+    t = env.t
+    pos_features = env.mp.basis_generator.basis(t)
+    plt.plot(t, pos_features)
+    plt.show()
+
+
 # This might work for some environments; however, please verify either way that the
 # correct trajectory information for your environment is extracted below
-SEED = 10
-env_id = "ball_in_cup-catch"
-wrappers = [dmc.ball_in_cup.MPWrapper]
+SEED = 1
+# env_id = "ball_in_cup-catch"
+env_id = "ALRReacherSparse-v0"
+wrappers = [mujoco.reacher.MPWrapper]

 mp_kwargs = {
-    "num_dof": 2,
-    "num_basis": 10,
-    "duration": 2,
-    "width": 0.025,
+    "num_dof": 5,
+    "num_basis": 8,
+    "duration": 4,
     "policy_type": "motor",
     "weights_scale": 1,
     "zero_start": True,
     "policy_kwargs": {
         "p_gains": 1,
-        "d_gains": 1
+        "d_gains": 0.1
     }
 }

-kwargs = dict(time_limit=2, episode_length=100)
+# kwargs = dict(time_limit=4, episode_length=200)
+kwargs = {}

-env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
-                     **kwargs)
+env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)

 # Plot difference between real trajectory and target MP trajectory
 env.reset()
-pos, vel = env.mp_rollout(env.action_space.sample())
+w = env.action_space.sample() * 10
+visualize(env)
+pos, vel = env.mp_rollout(w)

 base_shape = env.full_action_space.shape
 actual_pos = np.zeros((len(pos), *base_shape))
@@ -51,18 +61,22 @@ plt.figure(figsize=(15, 5))

 plt.subplot(131)
 plt.title("Position")
-plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
+p1 = plt.plot(actual_pos, c='C0', label="true")
 # plt.plot(actual_pos_ball, label="true pos ball")
-plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
+p2 = plt.plot(pos, c='C1', label="MP")  # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))])
 plt.xlabel("Episode steps")
-plt.legend()
+# plt.legend()
+handles, labels = plt.gca().get_legend_handles_labels()
+from collections import OrderedDict
+
+by_label = OrderedDict(zip(labels, handles))
+plt.legend(by_label.values(), by_label.keys())

 plt.subplot(132)
 plt.title("Velocity")
-plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
-plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
+plt.plot(actual_vel, c='C0', label="true")
+plt.plot(vel, c='C1', label="MP")
 plt.xlabel("Episode steps")
-plt.legend()

 plt.subplot(133)
 plt.title("Actions")
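Note: plotting a 2-D array creates one labelled line per column, so a plain plt.legend() would repeat "true"/"MP" once per joint; the OrderedDict pattern above keeps the first handle per unique label. Standalone illustration of the same trick:

    from collections import OrderedDict

    import numpy as np
    from matplotlib import pyplot as plt

    data = np.random.randn(50, 5).cumsum(axis=0)
    plt.plot(data, c='C0', label="true")  # five lines, all labelled "true"
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))  # one handle per unique label
    plt.legend(by_label.values(), by_label.keys())
    plt.show()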