Merge pull request #18 from ALRhub/deprecate_detpmp

Deprecate detpmp
This commit is contained in:
ottofabian 2022-01-11 15:57:33 +01:00 committed by GitHub
commit 04d27426ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 271 additions and 522 deletions

View File

@ -81,7 +81,7 @@ trajectory.
```python ```python
import alr_envs import alr_envs
env = alr_envs.make('HoleReacherDetPMP-v0', seed=1) env = alr_envs.make('HoleReacherProMP-v0', seed=1)
# render() can be called once in the beginning with all necessary arguments. To turn it of again just call render(None). # render() can be called once in the beginning with all necessary arguments. To turn it of again just call render(None).
env.render() env.render()
@ -95,7 +95,7 @@ for i in range(5):
``` ```
To show all available environments, we provide some additional convenience. Each value will return a dictionary with two To show all available environments, we provide some additional convenience. Each value will return a dictionary with two
keys `DMP` and `DetPMP` that store a list of available environment names. keys `DMP` and `ProMP` that store a list of available environment names.
```python ```python
import alr_envs import alr_envs
@ -193,7 +193,7 @@ mp_kwargs = {...}
kwargs = {...} kwargs = {...}
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs) env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required): # OR for a deterministic ProMP (other mp_kwargs are required):
# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
rewards = 0 rewards = 0
obs = env.reset() obs = env.reset()

View File

@ -1,5 +1,5 @@
from alr_envs import dmc, meta, open_ai from alr_envs import dmc, meta, open_ai
from alr_envs.utils.make_env_helpers import make, make_detpmp_env, make_dmp_env, make_promp_env, make_rank from alr_envs.utils.make_env_helpers import make, make_dmp_env, make_promp_env, make_rank
from alr_envs.utils import make_dmc from alr_envs.utils import make_dmc
# Convenience function for all MP environments # Convenience function for all MP environments

View File

@ -10,7 +10,9 @@ from .mujoco.ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
from .mujoco.reacher.alr_reacher import ALRReacherEnv from .mujoco.reacher.alr_reacher import ALRReacherEnv
from .mujoco.reacher.balancing import BalancingEnv from .mujoco.reacher.balancing import BalancingEnv
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# Classic Control # Classic Control
## Simple Reacher ## Simple Reacher
@ -195,16 +197,20 @@ register(
) )
## Table Tennis ## Table Tennis
from alr_envs.alr.mujoco.table_tennis.tt_gym import MAX_EPISODE_STEPS
register(id='TableTennis2DCtxt-v0', register(id='TableTennis2DCtxt-v0',
entry_point='alr_envs.alr.mujoco:TT_Env_Gym', entry_point='alr_envs.alr.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS, max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim':2}) kwargs={'ctxt_dim': 2})
register(id='TableTennis2DCtxt-v1',
entry_point='alr_envs.alr.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim': 2, 'fixed_goal': True})
register(id='TableTennis4DCtxt-v0', register(id='TableTennis4DCtxt-v0',
entry_point='alr_envs.alr.mujoco:TT_Env_Gym', entry_point='alr_envs.alr.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS, max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim':4}) kwargs={'ctxt_dim': 4})
## BeerPong ## BeerPong
difficulties = ["simple", "intermediate", "hard", "hardest"] difficulties = ["simple", "intermediate", "hard", "hardest"]
@ -240,8 +246,12 @@ for _v in _versions:
"duration": 2, "duration": 2,
"alpha_phase": 2, "alpha_phase": 2,
"learn_goal": True, "learn_goal": True,
"policy_type": "velocity", "policy_type": "motor",
"weights_scale": 50, "weights_scale": 50,
"policy_kwargs": {
"p_gains": .6,
"d_gains": .075
}
} }
} }
) )
@ -260,33 +270,16 @@ for _v in _versions:
"duration": 2, "duration": 2,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 1, "weights_scale": 1,
"zero_start": True "zero_start": True,
"policy_kwargs": {
"p_gains": .6,
"d_gains": .075
}
} }
} }
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'{_name[0]}DetPMP-{_name[1]}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:{_v}",
"wrappers": [classic_control.simple_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 2 if "long" not in _v.lower() else 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
# Viapoint reacher # Viapoint reacher
register( register(
id='ViaPointReacherDMP-v0', id='ViaPointReacherDMP-v0',
@ -318,7 +311,7 @@ register(
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"policy_type": "motor", "policy_type": "velocity",
"weights_scale": 1, "weights_scale": 1,
"zero_start": True "zero_start": True
} }
@ -326,26 +319,6 @@ register(
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0") ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
register(
id='ViaPointReacherDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:ViaPointReacher-v0",
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ViaPointReacherDetPMP-v0")
## Hole Reacher ## Hole Reacher
_versions = ["v0", "v1", "v2"] _versions = ["v0", "v1", "v2"]
for _v in _versions: for _v in _versions:
@ -391,43 +364,23 @@ for _v in _versions:
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'HoleReacherDetPMP-{_v}' ## Beerpong
_versions = ["v0", "v1", "v2", "v3"]
for _v in _versions:
_env_id = f'BeerpongProMP-{_v}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{_v}", "name": f"alr_envs:ALRBeerPong-{_v}",
"wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id)
## Beerpong
register(
id='BeerpongDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": "alr_envs:ALRBeerPong-v0",
"wrappers": [mujoco.beerpong.MPWrapper], "wrappers": [mujoco.beerpong.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 7, "num_dof": 7,
"num_basis": 2, "num_basis": 2,
"n_zero_bases": 2, "duration": 1,
"duration": 0.5, "post_traj_time": 2,
"post_traj_time": 2.5,
"width": 0.01,
"off": 0.01,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.08, "weights_scale": 1,
"zero_start": True, "zero_start": True,
"zero_goal": False, "zero_goal": False,
"policy_kwargs": { "policy_kwargs": {
@ -436,24 +389,24 @@ register(
} }
} }
} }
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("BeerpongDetPMP-v0") ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## Table Tennis ## Table Tennis
register( ctxt_dim = [2, 4]
id='TableTennisDetPMP-v0', for _v, cd in enumerate(ctxt_dim):
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', _env_id = f'TableTennisProMP-v{_v}'
register(
id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:TableTennis4DCtxt-v0", "name": "alr_envs:TableTennis{}DCtxt-v0".format(cd),
"wrappers": [mujoco.table_tennis.MPWrapper], "wrappers": [mujoco.table_tennis.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 7, "num_dof": 7,
"num_basis": 2, "num_basis": 2,
"n_zero_bases": 2,
"duration": 1.25, "duration": 1.25,
"post_traj_time": 4.5, "post_traj_time": 4.5,
"width": 0.01,
"off": 0.01,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 1.0, "weights_scale": 1.0,
"zero_start": True, "zero_start": True,
@ -464,5 +417,31 @@ register(
} }
} }
} }
)
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
register(
id='TableTennisProMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={
"name": "alr_envs:TableTennis2DCtxt-v1",
"wrappers": [mujoco.table_tennis.MPWrapper],
"mp_kwargs": {
"num_dof": 7,
"num_basis": 2,
"duration": 1.,
"post_traj_time": 2.5,
"policy_type": "motor",
"weights_scale": 1,
"off": -0.05,
"bandwidth_factor": 3.5,
"zero_start": True,
"zero_goal": False,
"policy_kwargs": {
"p_gains": 0.5*np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]),
"d_gains": 0.5*np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
}
}
}
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("TableTennisDetPMP-v0") ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v2")

View File

@ -14,8 +14,5 @@
|`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25 |`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25
|`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25 |`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25
|`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30 |`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30
|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupSimple-v0` task where only 3 joints are actuated. | 4000 | 15
|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
[//]: |`HoleReacherDetPMP-v0`| [//]: |`HoleReacherProMPP-v0`|

View File

@ -5,7 +5,6 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from gym.utils import seeding from gym.utils import seeding
from alr_envs.alr.classic_control.utils import check_self_collision
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv

View File

@ -2,5 +2,5 @@ from .reacher.alr_reacher import ALRReacherEnv
from .reacher.balancing import BalancingEnv from .reacher.balancing import BalancingEnv
from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
from .table_tennis.tt_gym import TT_Env_Gym from .table_tennis.tt_gym import TTEnvGym
from .beerpong.beerpong import ALRBeerBongEnv from .beerpong.beerpong import ALRBeerBongEnv

View File

@ -1,107 +0,0 @@
from alr_envs.alr.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
def make_contextual_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="contextual_goal")
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def _make_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="simple")
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.2, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def make_simple_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBallInACupEnv(reward_type="simple")
env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True, off=-0.1)
env.seed(seed + rank)
return env
return _init
def make_simple_dmp_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
_env = ALRBallInACupEnv(reward_type="simple")
_env = DmpWrapper(_env,
num_dof=3,
num_basis=5,
duration=3.5,
post_traj_time=4.5,
bandwidth_factor=2.5,
dt=_env.dt,
learn_goal=False,
alpha_phase=3,
start_pos=_env.start_pos[1::2],
final_pos=_env.start_pos[1::2],
policy_type="motor",
weights_scale=100,
)
_env.seed(seed + rank)
return _env
return _init

View File

@ -27,10 +27,10 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
self.ball_site_id = 0 self.ball_site_id = 0
self.ball_id = 11 self.ball_id = 11
self._release_step = 100 # time step of ball release self._release_step = 175 # time step of ball release
self.sim_time = 4 # seconds self.sim_time = 3 # seconds
self.ep_length = 600 # based on 5 seconds with dt = 0.005 int(self.sim_time / self.dt) self.ep_length = 600 # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt)
self.cup_table_id = 10 self.cup_table_id = 10
if noisy: if noisy:
@ -127,13 +127,14 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
self._steps += 1 self._steps += 1
else: else:
reward = -30 reward = -30
reward_infos = dict()
success = False success = False
is_collided = False is_collided = False
done = True done = True
ball_pos = np.zeros(3) ball_pos = np.zeros(3)
ball_vel = np.zeros(3) ball_vel = np.zeros(3)
return ob, reward, done, dict(reward_dist=reward_dist, infos = dict(reward_dist=reward_dist,
reward_ctrl=reward_ctrl, reward_ctrl=reward_ctrl,
reward=reward, reward=reward,
velocity=angular_vel, velocity=angular_vel,
@ -143,8 +144,11 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
q_vel=self.sim.data.qvel[0:7].ravel().copy(), q_vel=self.sim.data.qvel[0:7].ravel().copy(),
ball_pos=ball_pos, ball_pos=ball_pos,
ball_vel=ball_vel, ball_vel=ball_vel,
is_success=success, success=success,
is_collided=is_collided, sim_crash=crash) is_collided=is_collided, sim_crash=crash)
infos.update(reward_infos)
return ob, reward, done, infos
def check_traj_in_joint_limits(self): def check_traj_in_joint_limits(self):
return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min) return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)
@ -171,7 +175,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
if __name__ == "__main__": if __name__ == "__main__":
env = ALRBeerBongEnv(reward_type="no_context", difficulty='hardest') env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
# env.configure(ctxt) # env.configure(ctxt)
env.reset() env.reset()

View File

@ -71,6 +71,7 @@ class BeerPongReward:
goal_pos = env.sim.data.site_xpos[self.goal_id] goal_pos = env.sim.data.site_xpos[self.goal_id]
ball_pos = env.sim.data.body_xpos[self.ball_id] ball_pos = env.sim.data.body_xpos[self.ball_id]
ball_vel = env.sim.data.body_xvelp[self.ball_id]
goal_final_pos = env.sim.data.site_xpos[self.goal_final_id] goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
self.dists.append(np.linalg.norm(goal_pos - ball_pos)) self.dists.append(np.linalg.norm(goal_pos - ball_pos))
self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos)) self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
@ -131,6 +132,7 @@ class BeerPongReward:
infos["success"] = success infos["success"] = success
infos["is_collided"] = self._is_collided infos["is_collided"] = self._is_collided
infos["ball_pos"] = ball_pos.copy() infos["ball_pos"] = ball_pos.copy()
infos["ball_vel"] = ball_vel.copy()
infos["action_cost"] = 5e-4 * action_cost infos["action_cost"] = 5e-4 * action_cost
return reward, infos return reward, infos

View File

@ -81,32 +81,36 @@ class BeerPongReward:
action_cost = np.sum(np.square(action)) action_cost = np.sum(np.square(action))
self.action_costs.append(action_cost) self.action_costs.append(action_cost)
if not self.ball_table_contact:
self.ball_table_contact = self._check_collision_single_objects(env.sim, self.ball_collision_id,
self.table_collision_id)
self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids) self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids)
if env._steps == env.ep_length - 1 or self._is_collided: if env._steps == env.ep_length - 1 or self._is_collided:
min_dist = np.min(self.dists) min_dist = np.min(self.dists)
ball_table_bounce = self._check_collision_single_objects(env.sim, self.ball_collision_id, final_dist = self.dists_final[-1]
self.table_collision_id)
ball_cup_table_cont = self._check_collision_with_set_of_objects(env.sim, self.ball_collision_id,
self.cup_collision_ids)
ball_wall_cont = self._check_collision_single_objects(env.sim, self.ball_collision_id,
self.wall_collision_id)
ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id, ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id,
self.cup_table_collision_id) self.cup_table_collision_id)
if not ball_in_cup:
cost_offset = 2
if not ball_cup_table_cont and not ball_table_bounce and not ball_wall_cont:
cost_offset += 2
cost = cost_offset + min_dist ** 2 + 0.5 * self.dists_final[-1] ** 2 + 1e-7 * action_cost
else:
cost = self.dists_final[-1] ** 2 + 1.5 * action_cost * 1e-7
reward = - 1 * cost - self.collision_penalty * int(self._is_collided) # encourage bounce before falling into cup
if not ball_in_cup:
if not self.ball_table_contact:
reward = 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
else:
reward = (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
else:
if not self.ball_table_contact:
reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 1
else:
reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 3
# reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
success = ball_in_cup success = ball_in_cup
crash = self._is_collided crash = self._is_collided
else: else:
reward = - 1e-7 * action_cost reward = - 1e-2 * action_cost
cost = 0
success = False success = False
crash = False crash = False
@ -115,26 +119,11 @@ class BeerPongReward:
infos["is_collided"] = self._is_collided infos["is_collided"] = self._is_collided
infos["ball_pos"] = ball_pos.copy() infos["ball_pos"] = ball_pos.copy()
infos["ball_vel"] = ball_vel.copy() infos["ball_vel"] = ball_vel.copy()
infos["action_cost"] = 5e-4 * action_cost infos["action_cost"] = action_cost
infos["task_cost"] = cost infos["task_reward"] = reward
return reward, infos return reward, infos
def get_cost_offset(self):
if self.ball_ground_contact:
return 200
if not self.ball_table_contact:
return 100
if not self.ball_in_cup:
return 50
if self.ball_in_cup and self.ball_cup_contact and not self.noisy_bp:
return 10
return 0
def _check_collision_single_objects(self, sim, id_1, id_2): def _check_collision_single_objects(self, sim, id_1, id_2):
for coni in range(0, sim.data.ncon): for coni in range(0, sim.data.ncon):
con = sim.data.contact[coni] con = sim.data.contact[coni]

View File

@ -6,8 +6,6 @@ from gym.envs.mujoco import MujocoEnv
class ALRBeerpongEnv(MujocoEnv, utils.EzPickle): class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None): def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None):
utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@ -28,15 +26,13 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
self.context = None self.context = None
MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
# alr_mujoco_env.AlrMujocoEnv.__init__(self, # alr_mujoco_env.AlrMujocoEnv.__init__(self,
# self.xml_path, # self.xml_path,
# apply_gravity_comp=apply_gravity_comp, # apply_gravity_comp=apply_gravity_comp,
# n_substeps=n_substeps) # n_substeps=n_substeps)
self.sim_time = 8 # seconds self.sim_time = 8 # seconds
self.sim_steps = int(self.sim_time / self.dt) # self.sim_steps = int(self.sim_time / self.dt)
if reward_function is None: if reward_function is None:
from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward
reward_function = BeerpongReward reward_function = BeerpongReward
@ -46,6 +42,9 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
self.cup_table_id = self.sim.model._body_name2id["cup_table"] self.cup_table_id = self.sim.model._body_name2id["cup_table"]
# self.bounce_table_id = self.sim.model._body_name2id["bounce_table"] # self.bounce_table_id = self.sim.model._body_name2id["bounce_table"]
MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
utils.EzPickle.__init__(self)
@property @property
def current_pos(self): def current_pos(self):
return self.sim.data.qpos[0:7].copy() return self.sim.data.qpos[0:7].copy()
@ -90,7 +89,7 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
reward_ctrl = - np.square(a).sum() reward_ctrl = - np.square(a).sum()
action_cost = np.sum(np.square(a)) action_cost = np.sum(np.square(a))
crash = self.do_simulation(a) crash = self.do_simulation(a, self.frame_skip)
joint_cons_viol = self.check_traj_in_joint_limits() joint_cons_viol = self.check_traj_in_joint_limits()
self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy()) self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy())

View File

@ -1,72 +0,0 @@
from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
from alr_envs.alr.mujoco.beerpong.beerpong import ALRBeerpongEnv
from alr_envs.alr.mujoco.beerpong.beerpong_simple import ALRBeerpongEnv as ALRBeerpongEnvSimple
def make_contextual_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnv()
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def _make_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnvSimple()
env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init
def make_simple_env(rank, seed=0):
"""
Utility function for multiprocessed env.
:param env_id: (str) the environment ID
:param num_env: (int) the number of environments you wish to have in subprocesses
:param seed: (int) the initial seed for RNG
:param rank: (int) index of the subprocess
:returns a function that generates an environment
"""
def _init():
env = ALRBeerpongEnvSimple()
env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)
env.seed(seed + rank)
return env
return _init

View File

@ -10,7 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
#TODO: Check for simulation stability. Make sure the code runs even for sim crash #TODO: Check for simulation stability. Make sure the code runs even for sim crash
MAX_EPISODE_STEPS = 1375 MAX_EPISODE_STEPS = 1750
BALL_NAME_CONTACT = "target_ball_contact" BALL_NAME_CONTACT = "target_ball_contact"
BALL_NAME = "target_ball" BALL_NAME = "target_ball"
TABLE_NAME = 'table_tennis_table' TABLE_NAME = 'table_tennis_table'
@ -22,14 +22,19 @@ RACKET_NAME = 'bat'
CONTEXT_RANGE_BOUNDS_2DIM = np.array([[-1.2, -0.6], [-0.2, 0.0]]) CONTEXT_RANGE_BOUNDS_2DIM = np.array([[-1.2, -0.6], [-0.2, 0.0]])
CONTEXT_RANGE_BOUNDS_4DIM = np.array([[-1.35, -0.75, -1.25, -0.75], [-0.1, 0.75, -0.1, 0.75]]) CONTEXT_RANGE_BOUNDS_4DIM = np.array([[-1.35, -0.75, -1.25, -0.75], [-0.1, 0.75, -0.1, 0.75]])
class TT_Env_Gym(MujocoEnv, utils.EzPickle):
def __init__(self, ctxt_dim=2): class TTEnvGym(MujocoEnv, utils.EzPickle):
def __init__(self, ctxt_dim=2, fixed_goal=False):
model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml') model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml')
self.ctxt_dim = ctxt_dim self.ctxt_dim = ctxt_dim
self.fixed_goal = fixed_goal
if ctxt_dim == 2: if ctxt_dim == 2:
self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM
if self.fixed_goal:
self.goal = np.array([-1, -0.1, 0])
else:
self.goal = np.zeros(3) # 2 x,y + 1z self.goal = np.zeros(3) # 2 x,y + 1z
elif ctxt_dim == 4: elif ctxt_dim == 4:
self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM
@ -37,9 +42,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
else: else:
raise ValueError("either 2 or 4 dimensional Contexts available") raise ValueError("either 2 or 4 dimensional Contexts available")
action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2]) # has no effect as it is overwritten in init of super
action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2]) # action_space_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64') # action_space_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
# self.action_space = spaces.Box(low=action_space_low, high=action_space_high, dtype='float64')
self.time_steps = 0 self.time_steps = 0
self.init_qpos_tt = np.array([0, 0, 0, 1.5, 0, 0, 1.5, 0, 0, 0]) self.init_qpos_tt = np.array([0, 0, 0, 1.5, 0, 0, 1.5, 0, 0, 0])
@ -47,10 +53,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
self.reward_func = TT_Reward(self.ctxt_dim) self.reward_func = TT_Reward(self.ctxt_dim)
self.ball_landing_pos = None self.ball_landing_pos = None
self.hited_ball = False self.hit_ball = False
self.ball_contact_after_hit = False self.ball_contact_after_hit = False
self._ids_set = False self._ids_set = False
super(TT_Env_Gym, self).__init__(model_path=model_path, frame_skip=1) super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func. self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT] self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME] self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
@ -77,14 +83,17 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
return obs return obs
def sample_context(self): def sample_context(self):
return np.random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim) return self.np_random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
def reset_model(self): def reset_model(self):
self.set_state(self.init_qpos_tt, self.init_qvel_tt) # reset to initial sim state self.set_state(self.init_qpos_tt, self.init_qvel_tt) # reset to initial sim state
self.time_steps = 0 self.time_steps = 0
self.ball_landing_pos = None self.ball_landing_pos = None
self.hited_ball = False self.hit_ball = False
self.ball_contact_after_hit = False self.ball_contact_after_hit = False
if self.fixed_goal:
self.goal = self.goal[:2]
else:
self.goal = self.sample_context()[:2] self.goal = self.sample_context()[:2]
if self.ctxt_dim == 2: if self.ctxt_dim == 2:
initial_ball_state = ball_init(random=False) # fixed velocity, fixed position initial_ball_state = ball_init(random=False) # fixed velocity, fixed position
@ -122,12 +131,12 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
if not self._ids_set: if not self._ids_set:
self._set_ids() self._set_ids()
done = False done = False
episode_end = False if self.time_steps+1<MAX_EPISODE_STEPS else True episode_end = False if self.time_steps + 1 < MAX_EPISODE_STEPS else True
if not self.hited_ball: if not self.hit_ball:
self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side
if not self.hited_ball: if not self.hit_ball:
self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side
if self.hited_ball: if self.hit_ball:
if not self.ball_contact_after_hit: if not self.ball_contact_after_hit:
if self._contact_checker(self.ball_contact_id, self.floor_contact_id): # first check contact with floor if self._contact_checker(self.ball_contact_id, self.floor_contact_id): # first check contact with floor
self.ball_contact_after_hit = True self.ball_contact_after_hit = True
@ -140,7 +149,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
if self.ball_landing_pos is not None: if self.ball_landing_pos is not None:
done = True done = True
episode_end =True episode_end =True
reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hited_ball, self.ball_landing_pos) reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hit_ball, self.ball_landing_pos)
self.time_steps += 1 self.time_steps += 1
# gravity compensation on joints: # gravity compensation on joints:
#action += self.sim.data.qfrc_bias[:7].copy() #action += self.sim.data.qfrc_bias[:7].copy()
@ -151,7 +160,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
done = True done = True
reward = -25 reward = -25
ob = self._get_obs() ob = self._get_obs()
return ob, reward, done, {"hit_ball":self.hited_ball}# might add some information here .... info = {"hit_ball": self.hit_ball,
"q_pos": np.copy(self.sim.data.qpos[:7]),
"ball_pos": np.copy(self.sim.data.qpos[7:])}
return ob, reward, done, info # might add some information here ....
def set_context(self, context): def set_context(self, context):
old_state = self.sim.get_state() old_state = self.sim.get_state()

View File

@ -19,7 +19,7 @@ class TT_Reward:
# # seems to work for episodic case # # seems to work for episodic case
min_r_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj) - np.array(self.c_racket_traj), axis=1)) min_r_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj) - np.array(self.c_racket_traj), axis=1))
if not hited_ball: if not hited_ball:
return 0.2 * (1- np.tanh(min_r_b_dist**2)) return 0.2 * (1 - np.tanh(min_r_b_dist**2))
else: else:
if ball_landing_pos is None: if ball_landing_pos is None:
min_b_des_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj)[:,:2] - self.c_goal[:2], axis=1)) min_b_des_b_dist = np.min(np.linalg.norm(np.array(self.c_ball_traj)[:,:2] - self.c_goal[:2], axis=1))

View File

@ -11,9 +11,9 @@ environments in order to use our Motion Primitive gym interface with them.
|Name| Description|Trajectory Horizon|Action Dimension|Context Dimension |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
|---|---|---|---|---| |---|---|---|---|---|
|`dmc_ball_in_cup-catch_detpmp-v0`| A DetPmP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2 |`dmc_ball_in_cup-catch_promp-v0`| A ProMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000 | 10 | 2
|`dmc_ball_in_cup-catch_dmp-v0`| A DMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000| 10 | 2 |`dmc_ball_in_cup-catch_dmp-v0`| A DMP wrapped version of the "catch" task for the "ball_in_cup" environment. | 1000| 10 | 2
|`dmc_reacher-easy_detpmp-v0`| A DetPmP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4 |`dmc_reacher-easy_promp-v0`| A ProMP wrapped version of the "easy" task for the "reacher" environment. | 1000 | 10 | 4
|`dmc_reacher-easy_dmp-v0`| A DMP wrapped version of the "easy" task for the "reacher" environment. | 1000| 10 | 4 |`dmc_reacher-easy_dmp-v0`| A DMP wrapped version of the "easy" task for the "reacher" environment. | 1000| 10 | 4
|`dmc_reacher-hard_detpmp-v0`| A DetPmP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4 |`dmc_reacher-hard_promp-v0`| A ProMP wrapped version of the "hard" task for the "reacher" environment.| 1000 | 10 | 4
|`dmc_reacher-hard_dmp-v0`| A DMP wrapped version of the "hard" task for the "reacher" environment. | 1000 | 10 | 4 |`dmc_reacher-hard_dmp-v0`| A DMP wrapped version of the "hard" task for the "reacher" environment. | 1000 | 10 | 4

View File

@ -1,6 +1,6 @@
from . import manipulation, suite from . import manipulation, suite
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
from gym.envs.registration import register from gym.envs.registration import register
@ -34,8 +34,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
register( register(
id=f'dmc_ball_in_cup-catch_detpmp-v0', id=f'dmc_ball_in_cup-catch_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"ball_in_cup-catch", "name": f"ball_in_cup-catch",
"time_limit": 20, "time_limit": 20,
@ -45,7 +45,6 @@ register(
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
@ -55,7 +54,7 @@ register(
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_ball_in_cup-catch_detpmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
register( register(
id=f'dmc_reacher-easy_dmp-v0', id=f'dmc_reacher-easy_dmp-v0',
@ -86,8 +85,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
register( register(
id=f'dmc_reacher-easy_detpmp-v0', id=f'dmc_reacher-easy_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"reacher-easy", "name": f"reacher-easy",
"time_limit": 20, "time_limit": 20,
@ -97,7 +96,6 @@ register(
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
@ -108,7 +106,7 @@ register(
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-easy_detpmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
register( register(
id=f'dmc_reacher-hard_dmp-v0', id=f'dmc_reacher-hard_dmp-v0',
@ -139,8 +137,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
register( register(
id=f'dmc_reacher-hard_detpmp-v0', id=f'dmc_reacher-hard_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"reacher-hard", "name": f"reacher-hard",
"time_limit": 20, "time_limit": 20,
@ -150,7 +148,6 @@ register(
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
@ -161,7 +158,7 @@ register(
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_reacher-hard_detpmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-hard_promp-v0")
_dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"] _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
@ -196,10 +193,10 @@ for _task in _dmc_cartpole_tasks:
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-{_task}_detpmp-v0' _env_id = f'dmc_cartpole-{_task}_promp-v0'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"cartpole-{_task}", "name": f"cartpole-{_task}",
# "time_limit": 1, # "time_limit": 1,
@ -210,7 +207,6 @@ for _task in _dmc_cartpole_tasks:
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
@ -221,7 +217,7 @@ for _task in _dmc_cartpole_tasks:
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'dmc_cartpole-two_poles_dmp-v0' _env_id = f'dmc_cartpole-two_poles_dmp-v0'
register( register(
@ -253,10 +249,10 @@ register(
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-two_poles_detpmp-v0' _env_id = f'dmc_cartpole-two_poles_promp-v0'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"cartpole-two_poles", "name": f"cartpole-two_poles",
# "time_limit": 1, # "time_limit": 1,
@ -267,7 +263,6 @@ register(
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
@ -278,7 +273,7 @@ register(
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_env_id = f'dmc_cartpole-three_poles_dmp-v0' _env_id = f'dmc_cartpole-three_poles_dmp-v0'
register( register(
@ -310,10 +305,10 @@ register(
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'dmc_cartpole-three_poles_detpmp-v0' _env_id = f'dmc_cartpole-three_poles_promp-v0'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"cartpole-three_poles", "name": f"cartpole-three_poles",
# "time_limit": 1, # "time_limit": 1,
@ -324,7 +319,6 @@ register(
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
"width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
@ -335,7 +329,7 @@ register(
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# DeepMind Manipulation # DeepMind Manipulation
@ -364,8 +358,8 @@ register(
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
register( register(
id=f'dmc_manipulation-reach_site_detpmp-v0', id=f'dmc_manipulation-reach_site_promp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": f"manipulation-reach_site_features", "name": f"manipulation-reach_site_features",
# "time_limit": 1, # "time_limit": 1,
@ -375,11 +369,10 @@ register(
"num_dof": 9, "num_dof": 9,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
"width": 0.025,
"policy_type": "velocity", "policy_type": "velocity",
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
} }
} }
) )
ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("dmc_manipulation-reach_site_detpmp-v0") ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_manipulation-reach_site_promp-v0")

View File

@ -84,7 +84,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
} }
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples): # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
# env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# This renders the full MP trajectory # This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory. # It is only required to call render() once in the beginning, which renders every consecutive trajectory.
@ -128,7 +128,7 @@ if __name__ == '__main__':
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render) example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
# Gym + DMC hybrid task provided in the MP framework # Gym + DMC hybrid task provided in the MP framework
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render) example_dmc("dmc_ball_in_cup-catch_promp-v0", seed=10, iterations=1, render=render)
# Custom DMC task # Custom DMC task
# Different seed, because the episode is longer for this example and the name+seed combo is already registered above # Different seed, because the episode is longer for this example and the name+seed combo is already registered above

View File

@ -76,7 +76,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"policy_type": "metaworld", # custom controller type for metaworld environments "policy_type": "metaworld", # custom controller type for metaworld environments
} }
env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a DMP (other mp_kwargs are required, see dmc_examples): # OR for a DMP (other mp_kwargs are required, see dmc_examples):
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
@ -122,7 +122,7 @@ if __name__ == '__main__':
example_dmc("button-press-v2", seed=10, iterations=500, render=render) example_dmc("button-press-v2", seed=10, iterations=500, render=render)
# MP + MetaWorld hybrid task provided in the our framework # MP + MetaWorld hybrid task provided in the our framework
example_dmc("ButtonPressDetPMP-v2", seed=10, iterations=1, render=render) example_dmc("ButtonPressProMP-v2", seed=10, iterations=1, render=render)
# Custom MetaWorld task # Custom MetaWorld task
example_custom_dmc_and_mp(seed=10, iterations=1, render=render) example_custom_dmc_and_mp(seed=10, iterations=1, render=render)

View File

@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
} }
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a deterministic ProMP: # OR for a deterministic ProMP:
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
if render: if render:
env.render(mode="human") env.render(mode="human")
@ -147,7 +147,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__': if __name__ == '__main__':
render = True render = False
# DMP # DMP
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render) example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)

View File

@ -6,7 +6,7 @@ def example_mp(env_name, seed=1):
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
For more information on motion primitive specific stuff, look at the mp examples. For more information on motion primitive specific stuff, look at the mp examples.
Args: Args:
env_name: DetPMP env_id env_name: ProMP env_id
seed: seed seed: seed
Returns: Returns:
@ -35,7 +35,7 @@ if __name__ == '__main__':
# example_mp("ReacherDMP-v2") # example_mp("ReacherDMP-v2")
# DetProMP # DetProMP
example_mp("ContinuousMountainCarDetPMP-v0") example_mp("ContinuousMountainCarProMP-v0")
example_mp("ReacherDetPMP-v2") example_mp("ReacherProMP-v2")
example_mp("FetchReachDenseDetPMP-v1") example_mp("FetchReachDenseProMP-v1")
example_mp("FetchSlideDenseDetPMP-v1") example_mp("FetchSlideDenseProMP-v1")

View File

@ -2,7 +2,7 @@ import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from alr_envs import dmc, meta from alr_envs import dmc, meta
from alr_envs.utils.make_env_helpers import make_detpmp_env from alr_envs.utils.make_env_helpers import make_promp_env
# This might work for some environments, however, please verify either way the correct trajectory information # This might work for some environments, however, please verify either way the correct trajectory information
# for your environment are extracted below # for your environment are extracted below
@ -26,7 +26,7 @@ mp_kwargs = {
kwargs = dict(time_limit=2, episode_length=100) kwargs = dict(time_limit=2, episode_length=100)
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
**kwargs) **kwargs)
# Plot difference between real trajectory and target MP trajectory # Plot difference between real trajectory and target MP trajectory

View File

@ -3,7 +3,7 @@ from gym import register
from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \ from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper object_change_mp_wrapper
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# MetaWorld # MetaWorld
@ -12,10 +12,10 @@ _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "pl
for _task in _goal_change_envs: for _task in _goal_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_change_mp_wrapper.MPWrapper], "wrappers": [goal_change_mp_wrapper.MPWrapper],
@ -24,22 +24,21 @@ for _task in _goal_change_envs:
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.025,
"zero_start": True, "zero_start": True,
"policy_type": "metaworld", "policy_type": "metaworld",
} }
} }
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"] _object_change_envs = ["bin-picking-v2", "hammer-v2", "sweep-into-v2"]
for _task in _object_change_envs: for _task in _object_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [object_change_mp_wrapper.MPWrapper], "wrappers": [object_change_mp_wrapper.MPWrapper],
@ -48,13 +47,12 @@ for _task in _object_change_envs:
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.025,
"zero_start": True, "zero_start": True,
"policy_type": "metaworld", "policy_type": "metaworld",
} }
} }
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2", _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press-wall-v2", "button-press-topdown-v2",
"button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2", "button-press-topdown-wall-v2", "coffee-button-v2", "coffee-pull-v2",
@ -70,10 +68,10 @@ _goal_and_object_change_envs = ["box-close-v2", "button-press-v2", "button-press
for _task in _goal_and_object_change_envs: for _task in _goal_and_object_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_object_change_mp_wrapper.MPWrapper], "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
@ -82,22 +80,21 @@ for _task in _goal_and_object_change_envs:
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.025,
"zero_start": True, "zero_start": True,
"policy_type": "metaworld", "policy_type": "metaworld",
} }
} }
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
_goal_and_endeffector_change_envs = ["basketball-v2"] _goal_and_endeffector_change_envs = ["basketball-v2"]
for _task in _goal_and_endeffector_change_envs: for _task in _goal_and_endeffector_change_envs:
task_id_split = _task.split("-") task_id_split = _task.split("-")
name = "".join([s.capitalize() for s in task_id_split[:-1]]) name = "".join([s.capitalize() for s in task_id_split[:-1]])
_env_id = f'{name}DetPMP-{task_id_split[-1]}' _env_id = f'{name}ProMP-{task_id_split[-1]}'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper], "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
@ -106,10 +103,9 @@ for _task in _goal_and_endeffector_change_envs:
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.025,
"zero_start": True, "zero_start": True,
"policy_type": "metaworld", "policy_type": "metaworld",
} }
} }
) )
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append(_env_id) ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)

View File

@ -8,7 +8,7 @@ These environments are wrapped-versions of their OpenAI-gym counterparts.
|Name| Description|Trajectory Horizon|Action Dimension|Context Dimension |Name| Description|Trajectory Horizon|Action Dimension|Context Dimension
|---|---|---|---|---| |---|---|---|---|---|
|`ContinuousMountainCarDetPMP-v0`| A DetPmP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1 |`ContinuousMountainCarProMP-v0`| A ProMP wrapped version of the ContinuousMountainCar-v0 environment. | 100 | 1
|`ReacherDetPMP-v2`| A DetPmP wrapped version of the Reacher-v2 environment. | 50 | 2 |`ReacherProMP-v2`| A ProMP wrapped version of the Reacher-v2 environment. | 50 | 2
|`FetchSlideDenseDetPMP-v1`| A DetPmP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4 |`FetchSlideDenseProMP-v1`| A ProMP wrapped version of the FetchSlideDense-v1 environment. | 50 | 4
|`FetchReachDenseDetPMP-v1`| A DetPmP wrapped version of the FetchReachDense-v1 environment. | 50 | 4 |`FetchReachDenseProMP-v1`| A ProMP wrapped version of the FetchReachDense-v1 environment. | 50 | 4

View File

@ -3,7 +3,7 @@ from gym.wrappers import FlattenObservation
from . import classic_control, mujoco, robotics from . import classic_control, mujoco, robotics
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "DetPMP": []} ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
# Short Continuous Mountain Car # Short Continuous Mountain Car
register( register(
@ -16,8 +16,8 @@ register(
# Open AI # Open AI
# Classic Control # Classic Control
register( register(
id='ContinuousMountainCarDetPMP-v1', id='ContinuousMountainCarProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:MountainCarContinuous-v1", "name": "alr_envs:MountainCarContinuous-v1",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper], "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@ -26,7 +26,6 @@ register(
"num_basis": 4, "num_basis": 4,
"duration": 2, "duration": 2,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "motor", "policy_type": "motor",
"policy_kwargs": { "policy_kwargs": {
@ -36,11 +35,11 @@ register(
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v1") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v1")
register( register(
id='ContinuousMountainCarDetPMP-v0', id='ContinuousMountainCarProMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.classic_control:MountainCarContinuous-v0", "name": "gym.envs.classic_control:MountainCarContinuous-v0",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper], "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
@ -49,7 +48,6 @@ register(
"num_basis": 4, "num_basis": 4,
"duration": 19.98, "duration": 19.98,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "motor", "policy_type": "motor",
"policy_kwargs": { "policy_kwargs": {
@ -59,11 +57,11 @@ register(
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ContinuousMountainCarDetPMP-v0") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v0")
register( register(
id='ReacherDetPMP-v2', id='ReacherProMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.mujoco:Reacher-v2", "name": "gym.envs.mujoco:Reacher-v2",
"wrappers": [mujoco.reacher_v2.MPWrapper], "wrappers": [mujoco.reacher_v2.MPWrapper],
@ -72,7 +70,6 @@ register(
"num_basis": 6, "num_basis": 6,
"duration": 1, "duration": 1,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "motor", "policy_type": "motor",
"policy_kwargs": { "policy_kwargs": {
@ -82,11 +79,11 @@ register(
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("ReacherDetPMP-v2") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ReacherProMP-v2")
register( register(
id='FetchSlideDenseDetPMP-v1', id='FetchSlideDenseProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchSlideDense-v1", "name": "gym.envs.robotics:FetchSlideDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -95,17 +92,16 @@ register(
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "position" "policy_type": "position"
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDenseDetPMP-v1") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideDenseProMP-v1")
register( register(
id='FetchSlideDetPMP-v1', id='FetchSlideProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchSlide-v1", "name": "gym.envs.robotics:FetchSlide-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -114,17 +110,16 @@ register(
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "position" "policy_type": "position"
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchSlideDetPMP-v1") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchSlideProMP-v1")
register( register(
id='FetchReachDenseDetPMP-v1', id='FetchReachDenseProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchReachDense-v1", "name": "gym.envs.robotics:FetchReachDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -133,17 +128,16 @@ register(
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "position" "policy_type": "position"
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDenseDetPMP-v1") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachDenseProMP-v1")
register( register(
id='FetchReachDetPMP-v1', id='FetchReachProMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchReach-v1", "name": "gym.envs.robotics:FetchReach-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
@ -152,10 +146,9 @@ register(
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"post_traj_time": 0, "post_traj_time": 0,
"width": 0.02,
"zero_start": True, "zero_start": True,
"policy_type": "position" "policy_type": "position"
} }
} }
) )
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"].append("FetchReachDetPMP-v1") ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")

View File

@ -1,3 +1,4 @@
import warnings
from typing import Iterable, Type, Union from typing import Iterable, Type, Union
import gym import gym
@ -5,7 +6,6 @@ import numpy as np
from gym.envs.registration import EnvSpec from gym.envs.registration import EnvSpec
from mp_env_api import MPEnvWrapper from mp_env_api import MPEnvWrapper
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
@ -48,6 +48,11 @@ def make(env_id: str, seed, **kwargs):
Returns: Gym environment Returns: Gym environment
""" """
if any([det_pmp in env_id for det_pmp in ["DetPMP", "detpmp"]]):
warnings.warn("DetPMP is deprecated and converted to ProMP")
env_id = env_id.replace("DetPMP", "ProMP")
env_id = env_id.replace("detpmp", "promp")
try: try:
# Add seed to kwargs in case it is a predefined gym+dmc hybrid environment. # Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
if env_id.startswith("dmc"): if env_id.startswith("dmc"):
@ -153,26 +158,6 @@ def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwar
return ProMPWrapper(_env, **mp_kwargs) return ProMPWrapper(_env, **mp_kwargs)
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
"""
This can also be used standalone for manually building a custom Det ProMP environment.
Args:
env_id: base_env_name,
wrappers: list of wrappers (at least an MPEnvWrapper),
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
Returns: Det ProMP wrapped gym env
"""
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
_verify_dof(_env, mp_kwargs.get("num_dof"))
return DetPMPWrapper(_env, **mp_kwargs)
def make_dmp_env_helper(**kwargs): def make_dmp_env_helper(**kwargs):
""" """
Helper function for registering a DMP gym environments. Helper function for registering a DMP gym environments.
@ -212,26 +197,6 @@ def make_promp_env_helper(**kwargs):
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def make_detpmp_env_helper(**kwargs):
"""
Helper function for registering ProMP gym environments.
This can also be used standalone for manually building a custom ProMP environment.
Args:
**kwargs: expects at least the following:
{
"name": base_env_name,
"wrappers": list of wrappers (at least an MPEnvWrapper),
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
}
Returns: DMP wrapped gym env
"""
seed = kwargs.pop("seed", None)
return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
""" """
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives. When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.

View File

@ -98,8 +98,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']: for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
@ -110,8 +110,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']: for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
@ -122,8 +122,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']: for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
@ -134,8 +134,8 @@ class TestMPEnvironments(unittest.TestCase):
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']: for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['ProMP']:
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
@ -143,29 +143,29 @@ class TestMPEnvironments(unittest.TestCase):
"""Tests that identical seeds produce identical trajectories for ALR MP Envs.""" """Tests that identical seeds produce identical trajectories for ALR MP Envs."""
with self.subTest(msg="DMP"): with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"]) self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"]) self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_openai_environment_determinism(self): def test_openai_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs.""" """Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs."""
with self.subTest(msg="DMP"): with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"]) self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"]) self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_dmc_environment_determinism(self): def test_dmc_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for DMC MP Envs.""" """Tests that identical seeds produce identical trajectories for DMC MP Envs."""
with self.subTest(msg="DMP"): with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"]) self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"]) self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
def test_metaworld_environment_determinism(self): def test_metaworld_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories for Metaworld MP Envs.""" """Tests that identical seeds produce identical trajectories for Metaworld MP Envs."""
with self.subTest(msg="DMP"): with self.subTest(msg="DMP"):
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"]) self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
with self.subTest(msg="DetPMP"): with self.subTest(msg="ProMP"):
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"]) self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -81,13 +81,13 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
def _verify_done(self, done): def _verify_done(self, done):
self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.") self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")
def test_dmc_functionality(self): def test_metaworld_functionality(self):
"""Tests that environments runs without errors using random actions.""" """Tests that environments runs without errors using random actions."""
for env_id in ALL_ENVS: for env_id in ALL_ENVS:
with self.subTest(msg=env_id): with self.subTest(msg=env_id):
self._run_env(env_id) self._run_env(env_id)
def test_dmc_determinism(self): def test_metaworld_determinism(self):
"""Tests that identical seeds produce identical trajectories.""" """Tests that identical seeds produce identical trajectories."""
seed = 0 seed = 0
# Iterate over two trajectories, which should have the same state and action sequence # Iterate over two trajectories, which should have the same state and action sequence