diff --git a/README.md b/README.md
index 3a9f826..4db4f69 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ pip install -e .
In case you want to use dm_control oder metaworld, you can install them by specifying extras
```bash
-pip install -e .[dmc, metaworld]
+pip install -e .[dmc,metaworld]
```
> **Note:**
@@ -205,7 +205,7 @@ at the [examples](fancy_gym/examples/).
import fancy_gym
# Base environment name, according to structure of above example
-base_env_id = "ball_in_cup-catch"
+base_env_id = "dmc:ball_in_cup-catch"
# Replace this wrapper with the custom wrapper for your environment by inheriting from the RawInferfaceWrapper.
# You can also add other gym.Wrappers in case they are needed,
diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index a567d16..9619954 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -21,7 +21,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
learn_sub_trajectories: bool = False,
replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
- reward_aggregation: Callable[[np.ndarray], float] = np.sum
+ reward_aggregation: Callable[[np.ndarray], float] = np.sum,
+ max_planning_times: int = np.inf,
+ condition_on_desired: bool = False
):
"""
gym.Wrapper for leveraging a black box approach with a trajectory generator.
@@ -66,6 +68,14 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.render_kwargs = {}
self.verbose = verbose
+ # condition value
+ self.condition_on_desired = condition_on_desired
+ self.condition_pos = None
+ self.condition_vel = None
+
+ self.max_planning_times = max_planning_times
+ self.plan_steps = 0
+
def observation(self, observation):
# return context space if we are
if self.return_context_observation:
@@ -83,10 +93,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
self.traj_gen.set_params(clipped_params)
- bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
- # TODO we could think about initializing with the previous desired value in order to have a smooth transition
- # at least from the planning point of view.
- self.traj_gen.set_initial_conditions(bc_time, self.current_pos, self.current_vel)
+ init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
+
+ condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
+ condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
+
+ self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
position = get_numpy(self.traj_gen.get_traj_pos())
@@ -144,6 +156,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict()
done = False
+ self.plan_steps += 1
for t, (pos, vel) in enumerate(zip(position, velocity)):
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
@@ -162,8 +175,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if self.render_kwargs:
self.env.render(**self.render_kwargs)
- if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
- t + 1 + self.current_traj_steps):
+ if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+ t + 1 + self.current_traj_steps)
+ and self.plan_steps < self.max_planning_times):
+
+ self.condition_pos = pos if self.condition_on_desired else None
+ self.condition_vel = vel if self.condition_on_desired else None
+
break
infos.update({k: v[:t + 1] for k, v in infos.items()})
@@ -187,5 +205,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
self.current_traj_steps = 0
+ self.plan_steps = 0
self.traj_gen.reset()
+ self.condition_vel = None
+ self.condition_pos = None
return super(BlackBoxWrapper, self).reset()
diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py
index 2e8d61e..d504990 100644
--- a/fancy_gym/envs/__init__.py
+++ b/fancy_gym/envs/__init__.py
@@ -16,6 +16,8 @@ from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW
from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
+from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
+ BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
@@ -36,7 +38,8 @@ DEFAULT_BB_DICT_ProMP = {
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 5,
- 'num_basis_zero_start': 1
+ 'num_basis_zero_start': 1,
+ 'basis_bandwidth_factor': 3.0,
}
}
@@ -60,6 +63,29 @@ DEFAULT_BB_DICT_DMP = {
}
}
+DEFAULT_BB_DICT_ProDMP = {  # default black-box settings for ProDMP environments; deep-copied per environment
+    "name": "EnvName",  # placeholder, overwritten with the wrapped environment id
+    "wrappers": [],
+    "trajectory_generator_kwargs": {
+        "trajectory_generator_type": "prodmp",
+    },
+    "phase_generator_kwargs": {
+        "phase_generator_type": "exp",
+    },
+    "controller_kwargs": {
+        "controller_type": "motor",
+        "p_gains": 1.0,
+        "d_gains": 0.1,
+    },
+    "basis_generator_kwargs": {
+        "basis_generator_type": "prodmp",
+        "alpha": 10,
+        "num_basis": 5,
+    },
+    "black_box_kwargs": {
+    },
+}
+
# Classic Control
## Simple Reacher
register(
@@ -197,6 +223,14 @@ register(
max_episode_steps=MAX_EPISODE_STEPS_BEERPONG,
)
+# Register one box pushing environment per reward variant, with ids of the form BoxPushing<Variant>-v0
+for reward_type in ("Dense", "TemporalSparse", "TemporalSpatialSparse"):
+    register(
+        id=f'BoxPushing{reward_type}-v0',
+        entry_point=f'fancy_gym.envs.mujoco:BoxPushing{reward_type}',
+        max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
+    )
+
# Here we use the same reward as in BeerPong-v0, but now consider after the release,
# only one time step, i.e. we simulate until the end of th episode
register(
@@ -325,7 +359,6 @@ for _v in _versions:
kwargs=kwargs_dict_reacher_promp
)
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
-
########################################################################################################################
## Beerpong ProMP
_versions = ['BeerPong-v0']
@@ -430,7 +463,50 @@ for _v in _versions:
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# ########################################################################################################################
-#
+
+## Box Pushing: one ProMP black-box environment per box pushing environment id
+_versions = ['BoxPushingDense-v0', 'BoxPushingTemporalSparse-v0', 'BoxPushingTemporalSpatialSparse-v0']
+for _v in _versions:
+    _name = _v.split("-")  # [environment name, version suffix]
+    _env_id = f'{_name[0]}ProMP-{_name[1]}'  # e.g. BoxPushingDenseProMP-v0
+    kwargs_dict_box_pushing_promp = deepcopy(DEFAULT_BB_DICT_ProMP)  # copy, so the shared defaults are not mutated
+    kwargs_dict_box_pushing_promp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_box_pushing_promp['name'] = _v  # id of the environment that gets wrapped
+    kwargs_dict_box_pushing_promp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
+    kwargs_dict_box_pushing_promp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
+    kwargs_dict_box_pushing_promp['basis_generator_kwargs']['basis_bandwidth_factor'] = 2  # other values to try: 3.5, 4
+    # the gain arrays above hold one entry per controlled joint (7 entries)
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_box_pushing_promp
+    )
+    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+
+for _v in _versions:  # ReplanProDMP variant of every environment id listed in _versions
+    _name = _v.split("-")  # [environment name, version suffix]
+    _env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'  # e.g. BoxPushingDenseReplanProDMP-v0
+    kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)  # copy, so the shared defaults are not mutated
+    kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
+    kwargs_dict_box_pushing_prodmp['name'] = _v  # id of the environment that gets wrapped
+    kwargs_dict_box_pushing_prodmp['controller_kwargs']['p_gains'] = 0.01 * np.array([120., 120., 120., 120., 50., 30., 10.])
+    kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
+    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3
+    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3
+    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
+    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
+    kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4  # overrides the default of 5
+    kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
+    kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 4  # presumably forwarded to BlackBoxWrapper - confirm
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
+    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True  # schedule above is True whenever t is a multiple of 25
+    register(
+        id=_env_id,
+        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_box_pushing_prodmp
+    )
+    ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
#
# ## Walker2DJump
# _versions = ['Walker2DJump-v0']
diff --git a/fancy_gym/envs/mujoco/__init__.py b/fancy_gym/envs/mujoco/__init__.py
index 840691f..3254b4d 100644
--- a/fancy_gym/envs/mujoco/__init__.py
+++ b/fancy_gym/envs/mujoco/__init__.py
@@ -7,3 +7,4 @@ from .hopper_throw.hopper_throw import HopperThrowEnv
from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
from .reacher.reacher import ReacherEnv
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
+from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
diff --git a/fancy_gym/envs/mujoco/box_pushing/__init__.py b/fancy_gym/envs/mujoco/box_pushing/__init__.py
new file mode 100644
index 0000000..c5e6d2f
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/__init__.py
@@ -0,0 +1 @@
+from .mp_wrapper import MPWrapper
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/box_pushing.xml b/fancy_gym/envs/mujoco/box_pushing/assets/box_pushing.xml
new file mode 100644
index 0000000..516482f
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/box_pushing.xml
@@ -0,0 +1,42 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/kit_lab_surrounding.xml b/fancy_gym/envs/mujoco/box_pushing/assets/kit_lab_surrounding.xml
new file mode 100644
index 0000000..ca60a9c
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/kit_lab_surrounding.xml
@@ -0,0 +1,118 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/d435v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/d435v.stl
new file mode 100644
index 0000000..809a0da
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/d435v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/finger.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/finger.stl
new file mode 100644
index 0000000..3b87289
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/finger.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/fingerv.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/fingerv.stl
new file mode 100644
index 0000000..0b11382
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/fingerv.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/hand.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/hand.stl
new file mode 100644
index 0000000..4e82090
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/hand.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/handv.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/handv.stl
new file mode 100644
index 0000000..92f60bd
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/handv.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0.stl
new file mode 100644
index 0000000..def070c
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0v.stl
new file mode 100644
index 0000000..72dba55
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link0v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1.stl
new file mode 100644
index 0000000..426bcf2
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1v.stl
new file mode 100644
index 0000000..b42f97b
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link1v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2.stl
new file mode 100644
index 0000000..b369f15
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2v.stl
new file mode 100644
index 0000000..d72bbe4
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link2v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3.stl
new file mode 100644
index 0000000..25162ee
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3v.stl
new file mode 100644
index 0000000..904de9d
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link3v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4.stl
new file mode 100644
index 0000000..76c8c33
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4v.stl
new file mode 100644
index 0000000..da74ed6
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link4v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5.stl
new file mode 100644
index 0000000..3006a0b
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5v.stl
new file mode 100644
index 0000000..f795374
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link5v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6.stl
new file mode 100644
index 0000000..2e9594a
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6v.stl
new file mode 100644
index 0000000..8b2a7f3
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link6v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7.stl
new file mode 100644
index 0000000..0532d05
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7v.stl b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7v.stl
new file mode 100644
index 0000000..82b5946
Binary files /dev/null and b/fancy_gym/envs/mujoco/box_pushing/assets/meshes/panda/link7v.stl differ
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/panda_rod.xml b/fancy_gym/envs/mujoco/box_pushing/assets/panda_rod.xml
new file mode 100644
index 0000000..ac85629
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/panda_rod.xml
@@ -0,0 +1,159 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/push_box.xml b/fancy_gym/envs/mujoco/box_pushing/assets/push_box.xml
new file mode 100644
index 0000000..25313f8
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/push_box.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_mocap_control.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_mocap_control.xml
new file mode 100644
index 0000000..a36c987
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_mocap_control.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_position_control.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_position_control.xml
new file mode 100644
index 0000000..c83939a
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_position_control.xml
@@ -0,0 +1,50 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_torque_control.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_torque_control.xml
new file mode 100644
index 0000000..7b70a75
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_torque_control.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_velocity_control.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_velocity_control.xml
new file mode 100644
index 0000000..4fda9de
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/controller/panda_velocity_control.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/depr/panda_gym.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/depr/panda_gym.xml
new file mode 100644
index 0000000..27f3415
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/depr/panda_gym.xml
@@ -0,0 +1,157 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda.xml
new file mode 100644
index 0000000..8718e42
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda.xml
@@ -0,0 +1,155 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_bimanual.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_bimanual.xml
new file mode 100644
index 0000000..b970726
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_bimanual.xml
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_mocap.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_mocap.xml
new file mode 100644
index 0000000..7de5c38
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/panda_mocap.xml
@@ -0,0 +1,140 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/robots/pandas_kit_lab.xml b/fancy_gym/envs/mujoco/box_pushing/assets/robots/pandas_kit_lab.xml
new file mode 100644
index 0000000..b970726
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/robots/pandas_kit_lab.xml
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/fancy_gym/envs/mujoco/box_pushing/assets/surroundings/base.xml b/fancy_gym/envs/mujoco/box_pushing/assets/surroundings/base.xml
new file mode 100644
index 0000000..7a40fbd
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/assets/surroundings/base.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py
new file mode 100644
index 0000000..275bba1
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py
@@ -0,0 +1,362 @@
+import os
+
+import numpy as np
+from gym import utils, spaces
+from gym.envs.mujoco import MujocoEnv
+from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
+from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
+from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
+
+import mujoco
+
+MAX_EPISODE_STEPS_BOX_PUSHING = 100
+
+BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])
+
+class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
+ """
+    Franka box pushing environment
+    action space:
+        normalized joint torques * 7, range [-1, 1]
+    observation space:
+        joint positions (7), joint velocities (7), box position and quaternion, target position and quaternion
+    rewards:
+    1. dense reward
+    2. time-dependent sparse reward
+    3. time-spatial-dependent sparse reward
+ """
+
+    def __init__(self, frame_skip: int = 10):
+        utils.EzPickle.__init__(**locals())  # locals() is {self, frame_skip} at this point, passed as keywords
+        self._steps = 0  # incremented in step(), zeroed in reset_model()
+        self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.])
+        self.init_qvel_box_pushing = np.zeros(15)
+        self.frame_skip = frame_skip
+        # joint limits and desired rod orientation imported from box_pushing_utils
+        self._q_max = q_max
+        self._q_min = q_min
+        self._q_dot_max = q_dot_max
+        self._desired_rod_quat = desired_rod_quat
+        # running sum of squared scaled actions; reported in the step() info dict on the final step
+        self._episode_energy = 0.
+        MujocoEnv.__init__(self,
+                           model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
+                           frame_skip=self.frame_skip,
+                           mujoco_bindings="mujoco")
+        self.action_space = spaces.Box(low=-1, high=1, shape=(7,))  # assigned after MujocoEnv.__init__: 7 values in [-1, 1]
+
+    def step(self, action):
+        # Scale the clipped action by 10, add the bias forces (qfrc_bias) and limit the sum to the torque bounds.
+        action = 10 * np.clip(action, self.action_space.low, self.action_space.high)
+        torque = np.clip(action + self.data.qfrc_bias[:7].copy(), -q_torque_max, q_torque_max)
+
+        # An exception raised while simulating is printed and replaced by a fixed penalty below.
+        sim_failed = False
+        try:
+            self.do_simulation(torque, self.frame_skip)
+        except Exception as e:
+            print(e)
+            sim_failed = True
+
+        self._steps += 1
+        self._episode_energy += np.sum(np.square(action))
+        episode_end = self._steps >= MAX_EPISODE_STEPS_BOX_PUSHING
+
+        box = self.data.body("box_0")
+        target = self.data.body("replan_target_pos")
+        box_pos, box_quat = box.xpos.copy(), box.xquat.copy()
+        target_pos, target_quat = target.xpos.copy(), target.xquat.copy()
+        rod_tip_pos = self.data.site("rod_tip").xpos.copy()
+        rod_quat = self.data.body("push_rod").xquat.copy()
+        qpos, qvel = self.data.qpos[:7].copy(), self.data.qvel[:7].copy()
+
+        if sim_failed:
+            reward = -50
+        else:
+            reward = self._get_reward(episode_end, box_pos, box_quat, target_pos, target_quat,
+                                      rod_tip_pos, rod_quat, qpos, qvel, action)
+
+        obs = self._get_obs()
+        # goal distances, energy and success are only filled in on the final step, otherwise 0 / False
+        pos_dist = np.linalg.norm(box_pos - target_pos) if episode_end else 0.
+        rot_dist = rotation_distance(box_quat, target_quat) if episode_end else 0.
+        infos = {
+            'episode_end': episode_end,
+            'box_goal_pos_dist': pos_dist,
+            'box_goal_rot_dist': rot_dist,
+            'episode_energy': self._episode_energy if episode_end else 0.,
+            'is_success': bool(episode_end and pos_dist < 0.05 and rot_dist < 0.5),
+            'num_steps': self._steps
+        }
+        return obs, reward, episode_end, infos
+
+    def reset_model(self):
+        # reset robot and box to the initial state, then place the box at its fixed start pose
+        self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing)
+        box_init_pos = np.array([0.4, 0.3, -0.01, 0.0, 0.0, 0.0, 1.0])  # 3 position + 4 quaternion entries
+        self.data.joint("box_joint").qpos = box_init_pos
+
+        # sample a target pose; resample until its xy position is at least 0.3 away from the box start
+        box_target_pos = self.sample_context()
+        while np.linalg.norm(box_target_pos[:2] - box_init_pos[:2]) < 0.3:
+            box_target_pos = self.sample_context()
+        # fixed target for debugging: box_target_pos[0] = 0.4; box_target_pos[1] = -0.3;
+        #                             box_target_pos[-4:] = np.array([0.0, 0.0, 0.0, 1.0])
+        # NOTE(review): body indices 2 and 3 are presumably the target bodies in box_pushing.xml - confirm
+        self.model.body_pos[2] = box_target_pos[:3]
+        self.model.body_quat[2] = box_target_pos[-4:]
+        self.model.body_pos[3] = box_target_pos[:3]
+        self.model.body_quat[3] = box_target_pos[-4:]
+
+        # set the robot to the right configuration (rod tip in the box): solve IK for a TCP 0.15 above the box
+        desired_tcp_pos = box_init_pos[:3] + np.array([0.0, 0.0, 0.15])
+        desired_tcp_quat = np.array([0, 1, 0, 0])
+        desired_joint_pos = self.calculateOfflineIK(desired_tcp_pos, desired_tcp_quat)
+        self.data.qpos[:7] = desired_joint_pos
+
+        mujoco.mj_forward(self.model, self.data)  # recompute derived quantities for the new qpos
+        self._steps = 0
+        self._episode_energy = 0.
+
+        return self._get_obs()
+
+    def sample_context(self):
+        lower, upper = BOX_POS_BOUND  # rows: lower and upper corner of the sampling box
+        xyz = self.np_random.uniform(low=lower, high=upper)
+        yaw = self.np_random.uniform(low=0, high=np.pi * 2)  # drawn after the position, as before
+        return np.concatenate([xyz, rot_to_quat(yaw, np.array([0, 0, 1]))])
+
+    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
+                    rod_tip_pos, rod_quat, qpos, qvel, action):
+        raise NotImplementedError  # hook called by step(); each reward variant subclass overrides it
+
+    def _get_obs(self):
+        # Flat observation: joint positions, joint velocities, box pose, target pose.
+        # Bias forces and the rod tip pose are not part of the observation.
+        box = self.data.body("box_0")
+        target = self.data.body("replan_target_pos")
+        return np.concatenate([
+            self.data.qpos[:7].copy(),  # joint position
+            self.data.qvel[:7].copy(),  # joint velocity
+            box.xpos.copy(),  # position of box
+            box.xquat.copy(),  # orientation of box
+            target.xpos.copy(),  # position of target
+            target.xquat.copy(),  # orientation of target
+        ])
+
+    def _joint_limit_violate_penalty(self, qpos, qvel, enable_pos_limit=False, enable_vel_limit=False):
+        # Non-positive penalty: minus the summed overshoot beyond the enabled position / velocity limits.
+        p_coeff, v_coeff = 1., 1.
+        penalty = 0.
+        if enable_pos_limit:
+            # only joints outside [q_min, q_max] contribute
+            above = qpos > self._q_max
+            below = qpos < self._q_min
+            overshoot_high = abs(np.sum((qpos - self._q_max)[above]))
+            overshoot_low = abs(np.sum((self._q_min - qpos)[below]))
+            penalty -= p_coeff * (overshoot_high + overshoot_low)
+        if enable_vel_limit:
+            speed_excess = abs(qvel) - abs(self._q_dot_max)
+            penalty -= v_coeff * abs(np.sum(speed_excess[speed_excess > 0.]))
+        return penalty
+
+    def get_body_jacp(self, name):
+        body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, name)  # named enum instead of a bare 1
+        jacp = np.zeros((3, self.model.nv))  # translational Jacobian (3 x nv), filled in place by mj_jacBody
+        mujoco.mj_jacBody(self.model, self.data, jacp, None, body_id)
+        return jacp
+
+    def get_body_jacr(self, name):
+        body_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, name)  # named enum instead of a bare 1
+        jacr = np.zeros((3, self.model.nv))  # rotational Jacobian (3 x nv), filled in place by mj_jacBody
+        mujoco.mj_jacBody(self.model, self.data, None, jacr, body_id)
+        return jacr
+
+    def calculateOfflineIK(self, desired_cart_pos, desired_cart_quat):
+        """
+        calculate offline inverse kinematics for the Franka arm (iterative; overwrites self.data.qpos[:7])
+        :param desired_cart_pos: desired cartesian position of tool center point
+        :param desired_cart_quat: desired cartesian quaternion of tool center point
+        :return: joint angles
+        """
+        J_reg = 1e-6  # regularization added to J W J^T before solving
+        w = np.diag([1, 1, 1, 1, 1, 1, 1])  # joint weighting matrix (identity)
+        target_theta_null = np.array([  # joint configuration the null-space term pulls towards
+            3.57795216e-09,
+            1.74532920e-01,
+            3.30500960e-08,
+            -8.72664630e-01,
+            -1.14096181e-07,
+            1.22173047e00,
+            7.85398126e-01])
+        eps = 1e-5  # threshold for convergence
+        IT_MAX = 1000  # budget of accepted steps; rejected steps do not increment i
+        dt = 1e-3  # integration step size, adapted inside the loop
+        i = 0
+        pgain = [  # 6 gains multiplying the stacked position / orientation error
+            33.9403713446798,
+            30.9403713446798,
+            33.9403713446798,
+            27.69370238555632,
+            33.98706171459314,
+            30.9185531893281,
+        ]
+        pgain_null = 5 * np.array([  # per-joint gains of the null-space term
+            7.675519770796831,
+            2.676935478437176,
+            8.539040163444975,
+            1.270446361314313,
+            8.87752182480855,
+            2.186782233762969,
+            4.414432577659688,
+        ])
+        pgain_limit = 20  # gain of the joint-limit avoidance term
+        q = self.data.qpos[:7].copy()
+        qd_d = np.zeros(q.shape)
+        old_err_norm = np.inf
+
+        while True:
+            q_old = q
+            q = q + dt * qd_d  # integrate the last joint velocity command
+            q = np.clip(q, q_min, q_max)
+            self.data.qpos[:7] = q
+            mujoco.mj_forward(self.model, self.data)
+            current_cart_pos = self.data.body("tcp").xpos.copy()
+            current_cart_quat = self.data.body("tcp").xquat.copy()
+
+            cart_pos_error = np.clip(desired_cart_pos - current_cart_pos, -0.1, 0.1)
+            # pick the sign of the current quaternion that is closer to the desired quaternion
+            if np.linalg.norm(current_cart_quat - desired_cart_quat) > np.linalg.norm(current_cart_quat + desired_cart_quat):
+                current_cart_quat = -current_cart_quat
+            cart_quat_error = np.clip(get_quaternion_error(current_cart_quat, desired_cart_quat), -0.5, 0.5)
+
+            err = np.hstack((cart_pos_error, cart_quat_error))
+            err_norm = np.sum(cart_pos_error**2) + np.sum((current_cart_quat - desired_cart_quat)**2)
+            if err_norm > old_err_norm:  # the step increased the error: undo it and retry with a smaller dt
+                q = q_old
+                dt = 0.7 * dt
+                continue
+            else:  # step accepted: grow dt slightly
+                dt = 1.025 * dt
+
+            if err_norm < eps:
+                break
+            if i > IT_MAX:
+                break
+
+            old_err_norm = err_norm
+
+            ### get Jacobian by mujoco
+            self.data.qpos[:7] = q
+            mujoco.mj_forward(self.model, self.data)
+
+            jacp = self.get_body_jacp("tcp")[:, :7].copy()
+            jacr = self.get_body_jacr("tcp")[:, :7].copy()
+
+            J = np.concatenate((jacp, jacr), axis=0)  # 6 x 7: translational rows on top of rotational rows
+
+            Jw = J.dot(w)
+
+            # J * W * J.T + J_reg * I
+            JwJ_reg = Jw.dot(J.T) + J_reg * np.eye(J.shape[0])
+
+            # Null space velocity, points to home position
+            qd_null = pgain_null * (target_theta_null - q)
+
+            margin_to_limit = 0.1  # joints closer than this to a limit are pushed back
+            qd_null_limit = np.zeros(qd_null.shape)
+            qd_null_limit_max = pgain_limit * (q_max - margin_to_limit - q)
+            qd_null_limit_min = pgain_limit * (q_min + margin_to_limit - q)
+            qd_null_limit[q > q_max - margin_to_limit] += qd_null_limit_max[q > q_max - margin_to_limit]
+            qd_null_limit[q < q_min + margin_to_limit] += qd_null_limit_min[q < q_min + margin_to_limit]
+            qd_null += qd_null_limit
+
+            # W J.T (J W J' + reg I)^-1 xd_d + (I - W J.T (J W J' + reg I)^-1 J qd_null
+            qd_d = np.linalg.solve(JwJ_reg, pgain * err - J.dot(qd_null))
+
+            qd_d = w.dot(J.transpose()).dot(qd_d) + qd_null
+
+            i += 1
+
+        return q
+
+class BoxPushingDense(BoxPushingEnvBase):
+    def __init__(self, frame_skip: int = 10):
+        super().__init__(frame_skip=frame_skip)
+
+    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
+                    rod_tip_pos, rod_quat, qpos, qvel, action):
+        # Dense variant: every term, including the box-to-goal terms, is paid at each step.
+        reward = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
+        # rod tip to box distance, floored at 0.05
+        reward += -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
+        # box to goal position distance
+        reward += -3.5 * np.linalg.norm(box_pos - target_pos)
+        # box to goal rotation angle, divided by pi
+        reward += -rotation_distance(box_quat, target_quat) / np.pi
+        reward += -0.0005 * np.sum(np.square(action))
+
+        # extra penalty once the rod is rotated more than pi / 4 away from its desired orientation
+        tilt = rotation_distance(rod_quat, self._desired_rod_quat)
+        if tilt > np.pi / 4:
+            reward -= tilt / np.pi
+
+        return reward
+
+class BoxPushingTemporalSparse(BoxPushingEnvBase):
+    def __init__(self, frame_skip: int = 10):
+        super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip)
+
+    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
+                    rod_tip_pos, rod_quat, qpos, qvel, action):
+        reward = 0.
+        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
+        energy_cost = -0.0005 * np.sum(np.square(action))
+        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
+        reward += joint_penalty + tcp_box_dist_reward + energy_cost
+        rod_inclined_angle = rotation_distance(rod_quat, self._desired_rod_quat)
+        # use the instance attribute, as BoxPushingDense does, instead of the module-level constant
+        if rod_inclined_angle > np.pi / 4:
+            reward -= rod_inclined_angle / (np.pi)
+
+        if not episode_end:
+            return reward
+        # the box-to-goal terms below are only paid on the final step, scaled by 100
+        box_goal_dist = np.linalg.norm(box_pos - target_pos)
+
+        box_goal_pos_dist_reward = -3.5 * box_goal_dist * 100
+        box_goal_rot_dist_reward = -rotation_distance(box_quat, target_quat) / np.pi * 100
+
+        reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward
+
+        return reward
+
+class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
+
+    def __init__(self, frame_skip: int = 10):
+        super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip)
+
+    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
+                    rod_tip_pos, rod_quat, qpos, qvel, action):
+        reward = 0.
+        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
+        energy_cost = -0.0005 * np.sum(np.square(action))
+        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
+        reward += joint_penalty + tcp_box_dist_reward + energy_cost
+        rod_inclined_angle = rotation_distance(rod_quat, self._desired_rod_quat)
+        # use the instance attribute, as BoxPushingDense does, instead of the module-level constant
+        if rod_inclined_angle > np.pi / 4:
+            reward -= rod_inclined_angle / (np.pi)
+
+        if not episode_end:
+            return reward
+        # final step only: bonus and clipped goal terms apply when the box ends within 0.1 of the target
+        box_goal_dist = np.linalg.norm(box_pos - target_pos)
+
+        if box_goal_dist < 0.1:
+            reward += 300
+            box_goal_pos_dist_reward = np.clip(- 3.5 * box_goal_dist * 100 * 3, -100, 0)
+            box_goal_rot_dist_reward = np.clip(- rotation_distance(box_quat, target_quat)/np.pi * 100 * 1.5, -100, 0)
+            reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward
+
+        return reward
diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py
new file mode 100644
index 0000000..0b1919e
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py
@@ -0,0 +1,53 @@
+import numpy as np
+
+
+# joint constraints for Franka robot (one entry per joint, 7 joints)
+q_max = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973])  # upper bounds; presumably joint positions in rad -- TODO confirm
+q_min = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973])  # lower bounds matching q_max
+
+q_dot_max = np.array([2.1750, 2.1750, 2.1750, 2.1750, 2.6100, 2.6100, 2.6100])  # presumably joint velocity bounds (rad/s) -- TODO confirm
+q_torque_max = np.array([90., 90., 90., 90., 12., 12., 12.])  # presumably joint torque bounds (Nm) -- TODO confirm
+# reference quaternion the rod orientation is compared against (via rotation_distance) in the reward
+desired_rod_quat = np.array([0.0, 1.0, 0.0, 0.0])
+
+def skew(x):
+    """Returns the 3x3 skew-symmetric (cross-product) matrix built from x[0], x[1], x[2]."""
+    x0, x1, x2 = x[0], x[1], x[2]
+    return np.array([[0, -x2, x1],
+                     [x2, 0, -x0],
+                     [-x1, x0, 0]])
+
+def get_quaternion_error(curr_quat, des_quat):
+    """
+    Returns the difference between the current and the desired quaternion, with element 0
+    used as the scalar part and elements 1: as the vector part (Siciliano textbook page 140 Eq 3.91)
+    param curr_quat: current quaternion; param des_quat: desired quaternion
+    """
+    curr_scalar, curr_vec = curr_quat[0], curr_quat[1:]
+    des_scalar, des_vec = des_quat[0], des_quat[1:]
+    cross_term = skew(des_vec) @ curr_vec
+    return curr_scalar * des_vec - des_scalar * curr_vec - cross_term
+
+def rotation_distance(p: np.array, q: np.array):
+    """
+    Calculates the rotation angle between two unit quaternions
+    param p: quaternion
+    param q: quaternion
+    return: rotation angle between p and q (rad), within [0, pi]
+    """
+    assert p.shape == q.shape, "p and q should be quaternion"
+    # round-off can push |p.q| marginally above 1, where arccos returns NaN, so clip first
+    return 2 * np.arccos(np.clip(abs(p @ q), 0.0, 1.0))
+
+
+def rot_to_quat(theta, axis):
+    """
+    Converts rotation angle along an axis to a scalar-first quaternion
+    param theta: rotation angle (rad)
+    param axis: rotation axis (3 components, expected to have unit length)
+    return: quaternion [cos(theta/2), sin(theta/2) * axis]
+    """
+    quat = np.zeros(4)
+    quat[0] = np.cos(theta / 2.)  # scalar part, so that theta = 0 yields the identity [1, 0, 0, 0]
+    quat[1:] = np.sin(theta / 2.) * axis
+    return quat
diff --git a/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py
new file mode 100644
index 0000000..09b2d65
--- /dev/null
+++ b/fancy_gym/envs/mujoco/box_pushing/mp_wrapper.py
@@ -0,0 +1,29 @@
+from typing import Union, Tuple
+
+import numpy as np
+
+from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
+
+
+class MPWrapper(RawInterfaceWrapper):
+    """Provides the context mask and the first 7 entries of qpos / qvel as current position / velocity."""
+    # Random x goal + random init pos
+    @property
+    def context_mask(self):
+        layout = [  # (number of observation entries, part of the context?)
+            (7, False),  # joints position
+            (7, False),  # joints velocity
+            (3, False),  # position of box
+            (4, False),  # orientation of box
+            (3, True),  # position of target
+            (4, True),  # orientation of target
+        ]  # a trailing time entry is not part of the mask
+        return np.hstack([[in_context] * width for width, in_context in layout])
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.data.qpos[:7].copy()
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        return self.data.qvel[:7].copy()
diff --git a/fancy_gym/examples/example_replanning_envs.py b/fancy_gym/examples/example_replanning_envs.py
new file mode 100644
index 0000000..977ce9e
--- /dev/null
+++ b/fancy_gym/examples/example_replanning_envs.py
@@ -0,0 +1,62 @@
+import fancy_gym
+
+def example_run_replanning_env(env_name="BoxPushingDenseReplanProDMP-v0", seed=1, iterations=1, render=False):
+    """Runs `iterations` episodes of a registered replanning env with randomly sampled actions."""
+    env = fancy_gym.make(env_name, seed=seed)
+    env.reset()
+    for _ in range(iterations):
+        done = False
+        # `not done` instead of `done is False`: a numpy bool flag is never identical to False
+        while not done:
+            _, _, done, _ = env.step(env.action_space.sample())
+            if render:
+                env.render(mode="human")
+            if done:
+                env.reset()
+    env.close()
+
+def example_custom_replanning_envs(seed=0, iteration=100, render=True):
+    """Builds a replanning ProDMP env around a step-based env and feeds it `iteration` random actions."""
+    # id for a step-based environment
+    step_env_id = "BoxPushingDense-v0"
+    wrapper_classes = [fancy_gym.envs.mujoco.box_pushing.mp_wrapper.MPWrapper]
+
+    traj_gen_config = dict(trajectory_generator_type='prodmp', weight_scale=1)
+    phase_config = dict(phase_generator_type='exp')
+    controller_config = dict(controller_type='velocity')
+    basis_config = dict(basis_generator_type='prodmp', num_basis=5)
+
+    bb_config = dict(
+        # max_planning_times: the maximum number of plans that can be generated
+        max_planning_times=4,
+        # replanning_schedule: the trigger for replanning
+        replanning_schedule=lambda pos, vel, obs, action, t: t % 25 == 0,
+        # condition_on_desired: use desired state as the boundary condition for the next plan
+        condition_on_desired=True,
+    )
+
+    env = fancy_gym.make_bb(
+        env_id=step_env_id, wrappers=wrapper_classes, black_box_kwargs=bb_config,
+        traj_gen_kwargs=traj_gen_config, controller_kwargs=controller_config,
+        phase_kwargs=phase_config, basis_kwargs=basis_config, seed=seed,
+    )
+    if render:
+        env.render(mode="human")
+
+    env.reset()
+
+    for _ in range(iteration):
+        _, _, done, _ = env.step(env.action_space.sample())
+        if done:
+            env.reset()
+
+    env.close()
+    del env
+
+
+if __name__ == "__main__":
+ # run a registered replanning environment (one iteration, rendering disabled)
+ example_run_replanning_env(env_name="BoxPushingDenseReplanProDMP-v0", seed=1, iterations=1, render=False)
+
+ # run a custom replanning environment (8 black box steps, rendering enabled)
+ example_custom_replanning_envs(seed=0, iteration=8, render=True)
\ No newline at end of file
diff --git a/fancy_gym/examples/examples_movement_primitives.py b/fancy_gym/examples/examples_movement_primitives.py
index da7c94d..745e4e8 100644
--- a/fancy_gym/examples/examples_movement_primitives.py
+++ b/fancy_gym/examples/examples_movement_primitives.py
@@ -24,7 +24,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
# number of samples/full trajectories (multiple environment steps)
for i in range(iterations):
- if render and i % 2 == 0:
+ if render and i % 1 == 0:
# This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
# Resetting to no rendering, can be achieved by render(mode=None).
@@ -33,7 +33,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
# Just make sure the correct mode is set before executing the step.
env.render(mode="human")
else:
- env.render(mode=None)
+ env.render()
# Now the action space is not the raw action but the parametrization of the trajectory generator,
# such as a ProMP
@@ -161,6 +161,10 @@ if __name__ == '__main__':
# ProMP
example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
+ example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
+
+ # ProDMP
+ example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
# Altered basis functions
obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
diff --git a/fancy_gym/examples/examples_open_ai.py b/fancy_gym/examples/examples_open_ai.py
index a4a162d..789271f 100644
--- a/fancy_gym/examples/examples_open_ai.py
+++ b/fancy_gym/examples/examples_open_ai.py
@@ -22,7 +22,7 @@ def example_mp(env_name, seed=1, render=True):
if render and i % 2 == 0:
env.render(mode="human")
else:
- env.render(mode=None)
+ env.render()
ac = env.action_space.sample()
obs, reward, done, info = env.step(ac)
returns += reward
diff --git a/setup.py b/setup.py
index c029591..40480db 100644
--- a/setup.py
+++ b/setup.py
@@ -18,9 +18,9 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
setup(
author='Fabian Otto, Onur Celik',
name='fancy_gym',
- version='0.3',
+ version='0.2',
classifiers=[
- 'Development Status :: 4 - Beta',
+ 'Development Status :: 3 - Alpha',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
@@ -34,8 +34,8 @@ setup(
],
extras_require=extras,
install_requires=[
- 'gym[mujoco]<0.25.0,>=0.24.0',
- 'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main'
+ 'gym[mujoco]<0.25.0,>=0.24.1',
+ 'mp_pytorch<=0.1.3'
],
packages=[package for package in find_packages() if package.startswith("fancy_gym")],
package_data={
diff --git a/test/test_black_box.py b/test/test_black_box.py
index d5e3a88..5ade1ae 100644
--- a/test/test_black_box.py
+++ b/test/test_black_box.py
@@ -67,28 +67,32 @@ def test_missing_wrapper(env_id: str):
fancy_gym.make_bb(env_id, [], {}, {}, {}, {}, {})
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
def test_missing_local_state(mp_type: str):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+
env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
{'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'})
+ {'basis_generator_type': basis_generator_type})
env.reset()
with pytest.raises(NotImplementedError):
env.step(env.action_space.sample())
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
@pytest.mark.parametrize('verbose', [1, 2])
def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]], verbose: int):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+
env_id, wrapper_class = env_wrap
env = fancy_gym.make_bb(env_id, [wrapper_class], {'verbose': verbose},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
{'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'})
+ {'basis_generator_type': basis_generator_type})
env.reset()
info_keys = list(env.step(env.action_space.sample())[3].keys())
@@ -104,15 +108,17 @@ def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]
assert all(e in info_keys for e in mp_keys)
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+
env_id, wrapper_class = env_wrap
env = fancy_gym.make_bb(env_id, [wrapper_class], {},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
{'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'})
+ {'basis_generator_type': basis_generator_type})
for _ in range(5):
env.reset()
@@ -121,14 +127,15 @@ def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
assert length == env.spec.max_episode_steps
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'reward_aggregation': reward_aggregation},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
{'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'})
+ {'basis_generator_type': basis_generator_type})
env.reset()
# ToyEnv only returns 1 as reward
assert env.step(env.action_space.sample())[1] == reward_aggregation(np.ones(50, ))
@@ -149,12 +156,13 @@ def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapp
assert env.observation_space.shape == wrapper.context_mask[wrapper.context_mask].shape
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('num_dof', [0, 1, 2, 5])
@pytest.mark.parametrize('num_basis', [0, 1, 2, 5])
@pytest.mark.parametrize('learn_tau', [True, False])
@pytest.mark.parametrize('learn_delay', [True, False])
def test_action_space(mp_type: str, num_dof: int, num_basis: int, learn_tau: bool, learn_delay: bool):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
{'trajectory_generator_type': mp_type,
'action_dim': num_dof
@@ -164,28 +172,29 @@ def test_action_space(mp_type: str, num_dof: int, num_basis: int, learn_tau: boo
'learn_tau': learn_tau,
'learn_delay': learn_delay
},
- {'basis_generator_type': 'rbf',
+ {'basis_generator_type': basis_generator_type,
'num_basis': num_basis
})
base_dims = num_dof * num_basis
- additional_dims = num_dof if mp_type == 'dmp' else 0
+ additional_dims = num_dof if 'dmp' in mp_type else 0
traj_modification_dims = int(learn_tau) + int(learn_delay)
assert env.action_space.shape[0] == base_dims + traj_modification_dims + additional_dims
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('a', [1])
@pytest.mark.parametrize('b', [1.0])
@pytest.mark.parametrize('c', [[1], [1.0], ['str'], [{'a': 'b'}], [np.ones(3, )]])
@pytest.mark.parametrize('d', [{'a': 1}, {1: 2.0}, {'a': [1.0]}, {'a': np.ones(3, )}, {'a': {'a': 'b'}}])
@pytest.mark.parametrize('e', [Object()])
def test_change_env_kwargs(mp_type: str, a: int, b: float, c: list, d: dict, e: Object):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
{'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'},
+ {'basis_generator_type': basis_generator_type},
a=a, b=b, c=c, d=d, e=e
)
assert a is env.a
@@ -196,18 +205,20 @@ def test_change_env_kwargs(mp_type: str, a: int, b: float, c: list, d: dict, e:
assert e is env.e
-@pytest.mark.parametrize('mp_type', ['promp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('tau', [0.25, 0.5, 0.75, 1])
def test_learn_tau(mp_type: str, tau: float):
+ phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
{'trajectory_generator_type': mp_type,
},
{'controller_type': 'motor'},
- {'phase_generator_type': 'linear',
+ {'phase_generator_type': phase_generator_type,
'learn_tau': True,
'learn_delay': False
},
- {'basis_generator_type': 'rbf',
+ {'basis_generator_type': basis_generator_type,
}, seed=SEED)
d = True
@@ -228,26 +239,29 @@ def test_learn_tau(mp_type: str, tau: float):
vel = info['velocities'].flatten()
# Check end is all same (only true for linear basis)
- assert np.all(pos[tau_time_steps:] == pos[-1])
- assert np.all(vel[tau_time_steps:] == vel[-1])
+ if phase_generator_type == "linear":
+ assert np.all(pos[tau_time_steps:] == pos[-1])
+ assert np.all(vel[tau_time_steps:] == vel[-1])
# Check active trajectory section is different to end values
assert np.all(pos[:tau_time_steps - 1] != pos[-1])
assert np.all(vel[:tau_time_steps - 2] != vel[-1])
-
-
-@pytest.mark.parametrize('mp_type', ['promp'])
+#
+#
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('delay', [0, 0.25, 0.5, 0.75])
def test_learn_delay(mp_type: str, delay: float):
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+ phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
{'trajectory_generator_type': mp_type,
},
{'controller_type': 'motor'},
- {'phase_generator_type': 'linear',
+ {'phase_generator_type': phase_generator_type,
'learn_tau': False,
'learn_delay': True
},
- {'basis_generator_type': 'rbf',
+ {'basis_generator_type': basis_generator_type,
}, seed=SEED)
d = True
@@ -274,21 +288,23 @@ def test_learn_delay(mp_type: str, delay: float):
# Check active trajectory section is different to beginning values
assert np.all(pos[max(1, delay_time_steps):] != pos[0])
assert np.all(vel[max(1, delay_time_steps)] != vel[0])
-
-
-@pytest.mark.parametrize('mp_type', ['promp'])
+#
+#
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('tau', [0.25, 0.5, 0.75, 1])
@pytest.mark.parametrize('delay', [0.25, 0.5, 0.75, 1])
def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
+ phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
{'trajectory_generator_type': mp_type,
},
{'controller_type': 'motor'},
- {'phase_generator_type': 'linear',
+ {'phase_generator_type': phase_generator_type,
'learn_tau': True,
'learn_delay': True
},
- {'basis_generator_type': 'rbf',
+ {'basis_generator_type': basis_generator_type,
}, seed=SEED)
if env.spec.max_episode_steps * env.dt < delay + tau:
@@ -315,8 +331,9 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
vel = info['velocities'].flatten()
# Check end is all same (only true for linear basis)
- assert np.all(pos[joint_time_steps:] == pos[-1])
- assert np.all(vel[joint_time_steps:] == vel[-1])
+ if phase_generator_type == "linear":
+ assert np.all(pos[joint_time_steps:] == pos[-1])
+ assert np.all(vel[joint_time_steps:] == vel[-1])
# Check beginning is all same (only true for linear basis)
assert np.all(pos[:delay_time_steps - 1] == pos[0])
@@ -326,4 +343,4 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
active_pos = pos[delay_time_steps: joint_time_steps - 1]
active_vel = vel[delay_time_steps: joint_time_steps - 2]
assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
- assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])
+ assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])
\ No newline at end of file
diff --git a/test/test_replanning_sequencing.py b/test/test_replanning_sequencing.py
index a42bb65..9d04d02 100644
--- a/test/test_replanning_sequencing.py
+++ b/test/test_replanning_sequencing.py
@@ -98,7 +98,7 @@ def test_learn_sub_trajectories(mp_type: str, env_wrap: Tuple[str, Type[RawInter
assert length <= np.round(env.traj_gen.tau.numpy() / env.dt)
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
+@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
@pytest.mark.parametrize('add_time_aware_wrapper_before', [True, False])
@pytest.mark.parametrize('replanning_time', [10, 100, 1000])
@@ -114,11 +114,14 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
replanning_schedule = lambda c_pos, c_vel, obs, c_action, t: t % replanning_time == 0
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+ phase_generator_type = 'exp' if 'dmp' in mp_type else 'linear'
+
env = fancy_gym.make_bb(env_id, [wrapper_class], {'replanning_schedule': replanning_schedule, 'verbose': 2},
{'trajectory_generator_type': mp_type},
{'controller_type': 'motor'},
- {'phase_generator_type': 'exp'},
- {'basis_generator_type': 'rbf'}, seed=SEED)
+ {'phase_generator_type': phase_generator_type},
+ {'basis_generator_type': basis_generator_type}, seed=SEED)
assert env.do_replanning
assert callable(env.replanning_schedule)
@@ -142,3 +145,189 @@ def test_replanning_time(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWra
env.reset()
assert replanning_schedule(None, None, None, None, length)
+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10])
+def test_max_planning_times(mp_type: str, max_planning_times: int, sub_segment_steps: int):
+    """Counts black box steps until done and expects exactly `max_planning_times` of them."""
+    black_box_kwargs = {
+        'max_planning_times': max_planning_times,
+        'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+        'verbose': 2,
+    }
+    phase_kwargs = {
+        'phase_generator_type': 'exp' if mp_type == 'prodmp' else 'linear',
+        'learn_tau': False,
+        'learn_delay': False,
+    }
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], black_box_kwargs,
+                            {'trajectory_generator_type': mp_type},
+                            {'controller_type': 'motor'},
+                            phase_kwargs,
+                            {'basis_generator_type': 'prodmp' if mp_type == 'prodmp' else 'rbf'},
+                            seed=SEED)
+    env.reset()
+    done, planning_times = False, 0
+    while not done:
+        _, _, done, _ = env.step(env.action_space.sample())
+        planning_times += 1
+    assert planning_times == max_planning_times
+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10])
+@pytest.mark.parametrize('tau', [0.5, 1.0, 1.5, 2.0])
+def test_replanning_with_learn_tau(mp_type: str, max_planning_times: int, sub_segment_steps: int, tau: float):
+    """With learn_tau enabled the episode must still finish after exactly `max_planning_times` steps."""
+    black_box_kwargs = {
+        'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+        'max_planning_times': max_planning_times,
+        'verbose': 2,
+    }
+    phase_kwargs = {
+        'phase_generator_type': 'exp' if mp_type == 'prodmp' else 'linear',
+        'learn_tau': True,
+        'learn_delay': False,
+    }
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], black_box_kwargs,
+                            {'trajectory_generator_type': mp_type},
+                            {'controller_type': 'motor'},
+                            phase_kwargs,
+                            {'basis_generator_type': 'prodmp' if mp_type == 'prodmp' else 'rbf'},
+                            seed=SEED)
+    env.reset()
+    done, planning_times = False, 0
+    while not done:
+        sampled = env.action_space.sample()
+        sampled[0] = tau  # overwrite the first sampled entry with the tau under test
+        _, _, done, info = env.step(sampled)
+        planning_times += 1
+    assert planning_times == max_planning_times
+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10])
+@pytest.mark.parametrize('delay', [0.1, 0.25, 0.5, 0.75])
+def test_replanning_with_learn_delay(mp_type: str, max_planning_times: int, sub_segment_steps: int, delay: float):  # checks the logged trajectories per plan and the total number of plans
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+ phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
+ env = fancy_gym.make_bb('toy-v0', [ToyWrapper],
+ {'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+ 'max_planning_times': max_planning_times,
+ 'verbose': 2},
+ {'trajectory_generator_type': mp_type,
+ },
+ {'controller_type': 'motor'},
+ {'phase_generator_type': phase_generator_type,
+ 'learn_tau': False,
+ 'learn_delay': True
+ },
+ {'basis_generator_type': basis_generator_type,
+ },
+ seed=SEED)
+ _ = env.reset()
+ d = False
+ planning_times = 0  # number of env.step calls (one per plan) issued so far
+ while not d:
+ action = env.action_space.sample()
+ action[0] = delay  # overwrite the first sampled entry with the delay under test
+ _, _, d, info = env.step(action)
+
+ delay_time_steps = int(np.round(delay / env.dt))  # delay expressed in env steps
+ pos = info['positions'].flatten()
+ vel = info['velocities'].flatten()
+
+ # Check beginning is all same (only true for linear basis)
+ if planning_times == 0:
+ assert np.all(pos[:max(1, delay_time_steps - 1)] == pos[0])
+ assert np.all(vel[:max(1, delay_time_steps - 2)] == vel[0])
+
+ # only valid when delay < sub_segment_steps
+ elif planning_times > 0 and delay_time_steps < sub_segment_steps:  # planning_times > 0 always holds once the `if` above is not taken
+ assert np.all(pos[1:max(1, delay_time_steps - 1)] != pos[0])
+ assert np.all(vel[1:max(1, delay_time_steps - 2)] != vel[0])
+
+ # Check active trajectory section is different to beginning values
+ assert np.all(pos[max(1, delay_time_steps):] != pos[0])
+ assert np.all(vel[max(1, delay_time_steps)] != vel[0])  # NOTE(review): indexes a single element while the pos check slices with ':' -- confirm intent
+
+ planning_times += 1
+
+ assert planning_times == max_planning_times
+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10, 15])
+@pytest.mark.parametrize('delay', [0, 0.25, 0.5, 0.75])
+@pytest.mark.parametrize('tau', [0.5, 0.75, 1.0])
+def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: int, sub_segment_steps: int,
+ delay: float, tau: float):  # checks the logged trajectory of the first plan and the total number of plans
+ basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+ phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
+ env = fancy_gym.make_bb('toy-v0', [ToyWrapper],
+ {'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+ 'max_planning_times': max_planning_times,
+ 'verbose': 2},
+ {'trajectory_generator_type': mp_type,
+ },
+ {'controller_type': 'motor'},
+ {'phase_generator_type': phase_generator_type,
+ 'learn_tau': True,
+ 'learn_delay': True
+ },
+ {'basis_generator_type': basis_generator_type,
+ },
+ seed=SEED)
+ _ = env.reset()
+ d = False
+ planning_times = 0  # number of env.step calls (one per plan) issued so far
+ while not d:
+ action = env.action_space.sample()
+ action[0] = tau  # first sampled entry is replaced by the tau under test
+ action[1] = delay  # second sampled entry is replaced by the delay under test
+ _, _, d, info = env.step(action)
+
+ delay_time_steps = int(np.round(delay / env.dt))  # delay expressed in env steps
+
+ pos = info['positions'].flatten()
+ vel = info['velocities'].flatten()
+
+ # Delay only applies to first planning time
+ if planning_times == 0:
+ # Check delay is applied
+ assert np.all(pos[:max(1, delay_time_steps - 1)] == pos[0])
+ assert np.all(vel[:max(1, delay_time_steps - 2)] == vel[0])
+ # Check active trajectory section is different to beginning values
+ assert np.all(pos[max(1, delay_time_steps):] != pos[0])
+ assert np.all(vel[max(1, delay_time_steps)] != vel[0])  # NOTE(review): indexes a single element while the pos check slices with ':' -- confirm intent
+
+ planning_times += 1
+
+ assert planning_times == max_planning_times
+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10])
+def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_steps: int):
+    """Stepping exactly `max_planning_times` times must leave the done flag set."""
+    black_box_kwargs = {
+        'max_planning_times': max_planning_times,
+        'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+        'verbose': 2,
+    }
+    phase_kwargs = {
+        'phase_generator_type': 'exp' if mp_type == 'prodmp' else 'linear',
+        'learn_tau': False,
+        'learn_delay': False,
+    }
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], black_box_kwargs,
+                            {'trajectory_generator_type': mp_type},
+                            {'controller_type': 'motor'}, phase_kwargs,
+                            {'basis_generator_type': 'prodmp' if mp_type == 'prodmp' else 'rbf'},
+                            seed=SEED)
+    env.reset()
+    done = False
+    for _ in range(max_planning_times):
+        _, _, done, _ = env.step(env.action_space.sample())
+    assert done
diff --git a/test/utils.py b/test/utils.py
index 7ed8d61..dff2292 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -49,8 +49,8 @@ def run_env(env_id, iterations=None, seed=0, render=False):
if done:
break
-
- assert done, "Done flag is not True after end of episode."
+ if not hasattr(env, "replanning_schedule"):
+ assert done, "Done flag is not True after end of episode."
observations.append(obs)
env.close()
del env