restructuring

parent 8fe6a83271
commit 02b8a65bab
				| @ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper] | |||||||
| mp_kwargs = {...} | mp_kwargs = {...} | ||||||
| kwargs = {...} | kwargs = {...} | ||||||
| env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs) | env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs) | ||||||
| # OR for a deterministic ProMP (other mp_kwargs are required): | # OR for a deterministic ProMP (other traj_gen_kwargs are required): | ||||||
| # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) | # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args) | ||||||
| 
 | 
 | ||||||
| rewards = 0 | rewards = 0 | ||||||
| obs = env.reset() | obs = env.reset() | ||||||
|  | |||||||
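For reference, a hedged sketch of the ProMP variant referenced in the comment above. The helper name and the renamed keyword follow the README snippet; the env id and kwarg values are placeholders, not taken from the repo.

    import alr_envs

    base_env_id = "ball_in_cup-catch"                     # placeholder id
    wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper]
    traj_gen_kwargs = {"num_dof": 2, "num_basis": 5, "duration": 20}  # illustrative values

    # deterministic ProMP with the renamed keyword (was mp_kwargs before this commit)
    env = alr_envs.make_promp_env(base_env_id, wrappers=wrappers, seed=1,
                                  traj_gen_kwargs=traj_gen_kwargs)
    obs = env.reset()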
| @ -346,7 +346,7 @@ for _v in _versions: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": f"alr_envs:{_v}", |             "name": f"alr_envs:{_v}", | ||||||
|             "wrappers": [classic_control.simple_reacher.MPWrapper], |             "wrappers": [classic_control.simple_reacher.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 2 if "long" not in _v.lower() else 5, |                 "num_dof": 2 if "long" not in _v.lower() else 5, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 2, |                 "duration": 2, | ||||||
| @ -386,7 +386,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "alr_envs:ViaPointReacher-v0", |         "name": "alr_envs:ViaPointReacher-v0", | ||||||
|         "wrappers": [classic_control.viapoint_reacher.MPWrapper], |         "wrappers": [classic_control.viapoint_reacher.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 5, |             "num_dof": 5, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
| @ -424,7 +424,7 @@ for _v in _versions: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": f"alr_envs:HoleReacher-{_v}", |             "name": f"alr_envs:HoleReacher-{_v}", | ||||||
|             "wrappers": [classic_control.hole_reacher.MPWrapper], |             "wrappers": [classic_control.hole_reacher.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 5, |                 "num_dof": 5, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 2, |                 "duration": 2, | ||||||
| @ -467,7 +467,7 @@ for _v in _versions: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": f"alr_envs:{_v}", |             "name": f"alr_envs:{_v}", | ||||||
|             "wrappers": [mujoco.reacher.MPWrapper], |             "wrappers": [mujoco.reacher.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 5 if "long" not in _v.lower() else 7, |                 "num_dof": 5 if "long" not in _v.lower() else 7, | ||||||
|                 "num_basis": 2, |                 "num_basis": 2, | ||||||
|                 "duration": 4, |                 "duration": 4, | ||||||
|  | |||||||
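Putting one of the renamed registration entries together, as a hedged sketch: only the kwargs layout follows the hunks above; the env id, entry point, and import paths are assumptions.

    from gym.envs.registration import register
    from alr_envs import classic_control

    register(
        id="ViaPointReacherDMP-v0",                                          # hypothetical id
        entry_point="alr_envs.utils.make_env_helpers:make_dmp_env_helper",   # assumed entry point
        kwargs={
            "name": "alr_envs:ViaPointReacher-v0",
            "wrappers": [classic_control.viapoint_reacher.MPWrapper],
            "traj_gen_kwargs": {                                             # renamed from "mp_kwargs"
                "num_dof": 5,
                "num_basis": 5,
                "duration": 2,
            },
        },
    )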
| @ -1,12 +1,13 @@ | |||||||
| from alr_envs.mp.episodic_wrapper import EpisodicWrapper | from alr_envs.mp.black_box_wrapper import BlackBoxWrapper | ||||||
| from typing import Union, Tuple | from typing import Union, Tuple | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
|  | from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class NewMPWrapper(EpisodicWrapper): | class NewMPWrapper(RawInterfaceWrapper): | ||||||
| 
 | 
 | ||||||
|     def set_active_obs(self): |     def get_context_mask(self): | ||||||
|         return np.hstack([ |         return np.hstack([ | ||||||
|             [False] * 111, # ant has 111 dimensional observation space !! |             [False] * 111, # ant has 111 dimensional observation space !! | ||||||
|             [True] # goal height |             [True] # goal height | ||||||
|  | |||||||
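A short illustration of what such a boolean mask does once the BlackBoxWrapper (further down) applies it to an observation; the sizes mirror the ant wrapper above.

    import numpy as np

    obs = np.arange(112.0)                        # 111-dim ant observation + goal height
    context_mask = np.hstack([[False] * 111, [True]])
    context = obs[context_mask]                   # only the goal height is kept as context
    assert context.shape == (1,)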
| @ -1,15 +1,11 @@ | |||||||
| from typing import Tuple, Union | from typing import Union, Tuple | ||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| from alr_envs.mp.episodic_wrapper import EpisodicWrapper | from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class NewMPWrapper(EpisodicWrapper): | class NewMPWrapper(RawInterfaceWrapper): | ||||||
| 
 |  | ||||||
|     # def __init__(self, replanning_model): |  | ||||||
|     #     self.replanning_model = replanning_model |  | ||||||
| 
 |  | ||||||
|     @property |     @property | ||||||
|     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|         return self.env.sim.data.qpos[0:7].copy() |         return self.env.sim.data.qpos[0:7].copy() | ||||||
| @ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper): | |||||||
|     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|         return self.env.sim.data.qvel[0:7].copy() |         return self.env.sim.data.qvel[0:7].copy() | ||||||
| 
 | 
 | ||||||
|     def set_active_obs(self): |     def get_context_mask(self): | ||||||
|         return np.hstack([ |         return np.hstack([ | ||||||
|             [False] * 7,  # cos |             [False] * 7,  # cos | ||||||
|             [False] * 7,  # sin |             [False] * 7,  # sin | ||||||
| @ -27,12 +23,7 @@ class NewMPWrapper(EpisodicWrapper): | |||||||
|             [False] * 3,  # cup_goal_diff_top |             [False] * 3,  # cup_goal_diff_top | ||||||
|             [True] * 2,  # xy position of cup |             [True] * 2,  # xy position of cup | ||||||
|             [False]  # env steps |             [False]  # env steps | ||||||
|             ]) |         ]) | ||||||
| 
 |  | ||||||
|     def do_replanning(self, pos, vel, s, a, t, last_replan_step): |  | ||||||
|         return False |  | ||||||
|         # const = np.arange(0, 1000, 10) |  | ||||||
|         # return bool(self.replanning_model(s)) |  | ||||||
| 
 | 
 | ||||||
|     def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: |     def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: | ||||||
|         if self.mp.learn_tau: |         if self.mp.learn_tau: | ||||||
|  | |||||||
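The hunk above is cut off. As a hedged sketch only (not the repo's code), an _episode_callback that splits an env-specific parameter, e.g. the Beer Pong ball release time mentioned in the interface docstrings below, off the MP parameters could look like this:

    import numpy as np

    def _episode_callback(self, action: np.ndarray):
        # last entry: env-specific parameter (e.g. ball release time), rest: MP parameters
        return action[:-1], action[-1:]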
| @ -1,9 +1,9 @@ | |||||||
| from alr_envs.mp.episodic_wrapper import EpisodicWrapper | from alr_envs.mp.black_box_wrapper import BlackBoxWrapper | ||||||
| from typing import Union, Tuple | from typing import Union, Tuple | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class NewMPWrapper(EpisodicWrapper): | class NewMPWrapper(BlackBoxWrapper): | ||||||
|     @property |     @property | ||||||
|     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|         return self.env.sim.data.qpos[3:6].copy() |         return self.env.sim.data.qpos[3:6].copy() | ||||||
| @ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper): | |||||||
|     #     ]) |     #     ]) | ||||||
| 
 | 
 | ||||||
|     # Random x goal + random init pos |     # Random x goal + random init pos | ||||||
|     def set_active_obs(self): |     def get_context_mask(self): | ||||||
|         return np.hstack([ |         return np.hstack([ | ||||||
|                 [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position |                 [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position | ||||||
|                 [True] * 3,    # set to true if randomize initial pos |                 [True] * 3,    # set to true if randomize initial pos | ||||||
| @ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class NewHighCtxtMPWrapper(NewMPWrapper): | class NewHighCtxtMPWrapper(NewMPWrapper): | ||||||
|     def set_active_obs(self): |     def get_context_mask(self): | ||||||
|         return np.hstack([ |         return np.hstack([ | ||||||
|             [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position |             [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position | ||||||
|             [True] * 3,  # set to true if randomize initial pos |             [True] * 3,  # set to true if randomize initial pos | ||||||
|  | |||||||
| @ -149,4 +149,4 @@ if __name__ == '__main__': | |||||||
|         if d: |         if d: | ||||||
|             env.reset() |             env.reset() | ||||||
| 
 | 
 | ||||||
|     env.close() |     env.close() | ||||||
|  | |||||||
| @ -1,9 +1,9 @@ | |||||||
| from alr_envs.mp.episodic_wrapper import EpisodicWrapper | from alr_envs.mp.black_box_wrapper import BlackBoxWrapper | ||||||
| from typing import Union, Tuple | from typing import Union, Tuple | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class MPWrapper(EpisodicWrapper): | class MPWrapper(BlackBoxWrapper): | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
| @ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper): | |||||||
|     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: |     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|         return self.env.sim.data.qvel.flat[:self.env.n_links] |         return self.env.sim.data.qvel.flat[:self.env.n_links] | ||||||
| 
 | 
 | ||||||
|     def set_active_obs(self): |     def get_context_mask(self): | ||||||
|         return np.concatenate([ |         return np.concatenate([ | ||||||
|             [False] * self.env.n_links,  # cos |             [False] * self.env.n_links,  # cos | ||||||
|             [False] * self.env.n_links,  # sin |             [False] * self.env.n_links,  # sin | ||||||
|  | |||||||
| @ -15,7 +15,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.ball_in_cup.MPWrapper], |         "wrappers": [suite.ball_in_cup.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -41,7 +41,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.ball_in_cup.MPWrapper], |         "wrappers": [suite.ball_in_cup.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -65,7 +65,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.reacher.MPWrapper], |         "wrappers": [suite.reacher.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -92,7 +92,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.reacher.MPWrapper], |         "wrappers": [suite.reacher.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -117,7 +117,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.reacher.MPWrapper], |         "wrappers": [suite.reacher.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -144,7 +144,7 @@ register( | |||||||
|         "time_limit": 20, |         "time_limit": 20, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.reacher.MPWrapper], |         "wrappers": [suite.reacher.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 20, |             "duration": 20, | ||||||
| @ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks: | |||||||
|             "camera_id": 0, |             "camera_id": 0, | ||||||
|             "episode_length": 1000, |             "episode_length": 1000, | ||||||
|             "wrappers": [suite.cartpole.MPWrapper], |             "wrappers": [suite.cartpole.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 1, |                 "num_dof": 1, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 10, |                 "duration": 10, | ||||||
| @ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks: | |||||||
|             "camera_id": 0, |             "camera_id": 0, | ||||||
|             "episode_length": 1000, |             "episode_length": 1000, | ||||||
|             "wrappers": [suite.cartpole.MPWrapper], |             "wrappers": [suite.cartpole.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 1, |                 "num_dof": 1, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 10, |                 "duration": 10, | ||||||
| @ -230,7 +230,7 @@ register( | |||||||
|         "camera_id": 0, |         "camera_id": 0, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.cartpole.TwoPolesMPWrapper], |         "wrappers": [suite.cartpole.TwoPolesMPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
| @ -259,7 +259,7 @@ register( | |||||||
|         "camera_id": 0, |         "camera_id": 0, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.cartpole.TwoPolesMPWrapper], |         "wrappers": [suite.cartpole.TwoPolesMPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
| @ -286,7 +286,7 @@ register( | |||||||
|         "camera_id": 0, |         "camera_id": 0, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.cartpole.ThreePolesMPWrapper], |         "wrappers": [suite.cartpole.ThreePolesMPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
| @ -315,7 +315,7 @@ register( | |||||||
|         "camera_id": 0, |         "camera_id": 0, | ||||||
|         "episode_length": 1000, |         "episode_length": 1000, | ||||||
|         "wrappers": [suite.cartpole.ThreePolesMPWrapper], |         "wrappers": [suite.cartpole.ThreePolesMPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
| @ -342,7 +342,7 @@ register( | |||||||
|         # "time_limit": 1, |         # "time_limit": 1, | ||||||
|         "episode_length": 250, |         "episode_length": 250, | ||||||
|         "wrappers": [manipulation.reach_site.MPWrapper], |         "wrappers": [manipulation.reach_site.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 9, |             "num_dof": 9, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
| @ -365,7 +365,7 @@ register( | |||||||
|         # "time_limit": 1, |         # "time_limit": 1, | ||||||
|         "episode_length": 250, |         "episode_length": 250, | ||||||
|         "wrappers": [manipulation.reach_site.MPWrapper], |         "wrappers": [manipulation.reach_site.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 9, |             "num_dof": 9, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 10, |             "duration": 10, | ||||||
|  | |||||||
| @ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): | |||||||
|         "learn_goal": True,  # learn the goal position (recommended) |         "learn_goal": True,  # learn the goal position (recommended) | ||||||
|         "alpha_phase": 2, |         "alpha_phase": 2, | ||||||
|         "bandwidth_factor": 2, |         "bandwidth_factor": 2, | ||||||
|         "policy_type": "motor",  # controller type, 'velocity', 'position', and 'motor' (torque control) |         "policy_type": "motor",  # tracking_controller type, 'velocity', 'position', and 'motor' (torque control) | ||||||
|         "weights_scale": 1,  # scaling of MP weights |         "weights_scale": 1,  # scaling of MP weights | ||||||
|         "goal_scale": 1,  # scaling of learned goal position |         "goal_scale": 1,  # scaling of learned goal position | ||||||
|         "policy_kwargs": {  # only required for torque control/PD-Controller |         "policy_kwargs": {  # only required for torque control/PD-Controller | ||||||
| @ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): | |||||||
|         # "frame_skip": 1 |         # "frame_skip": 1 | ||||||
|     } |     } | ||||||
|     env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) |     env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) | ||||||
|     # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples): |     # OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples): | ||||||
|     # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) |     # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args) | ||||||
| 
 | 
 | ||||||
|     # This renders the full MP trajectory |     # This renders the full MP trajectory | ||||||
|     # It is only required to call render() once in the beginning, which renders every consecutive trajectory. |     # It is only required to call render() once in the beginning, which renders every consecutive trajectory. | ||||||
|  | |||||||
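The controller-related keys from the dict above, isolated as a hedged sketch: p_gains appears in the PDController further down, while the d_gains name and the numeric values are assumptions.

    traj_gen_kwargs = {
        # 'motor' selects torque control via the PD tracking controller
        "policy_type": "motor",
        "policy_kwargs": {        # only required for torque control / PD-Controller
            "p_gains": 50.0,      # illustrative gains, not taken from the repo
            "d_gains": 1.0,
        },
    }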
| @ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): | |||||||
|         "width": 0.025,  # width of the basis functions |         "width": 0.025,  # width of the basis functions | ||||||
|         "zero_start": True,  # start from current environment position if True |         "zero_start": True,  # start from current environment position if True | ||||||
|         "weights_scale": 1,  # scaling of MP weights |         "weights_scale": 1,  # scaling of MP weights | ||||||
|         "policy_type": "metaworld",  # custom controller type for metaworld environments |         "policy_type": "metaworld",  # custom tracking_controller type for metaworld environments | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) |     env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) | ||||||
|     # OR for a DMP (other mp_kwargs are required, see dmc_examples): |     # OR for a DMP (other traj_gen_kwargs are required, see dmc_examples): | ||||||
|     # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) |     # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs) | ||||||
| 
 | 
 | ||||||
|     # This renders the full MP trajectory |     # This renders the full MP trajectory | ||||||
|     # It is only required to call render() once in the beginning, which renders every consecutive trajectory. |     # It is only required to call render() once in the beginning, which renders every consecutive trajectory. | ||||||
|  | |||||||
| @ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations= | |||||||
|     Returns: |     Returns: | ||||||
| 
 | 
 | ||||||
|     """ |     """ | ||||||
|     # Changing the mp_kwargs is possible by providing them to gym. |     # Changing the traj_gen_kwargs is possible by providing them to gym. | ||||||
|     # E.g. here by providing way to many basis functions |     # E.g. here by providing way to many basis functions | ||||||
|     mp_kwargs = { |     mp_kwargs = { | ||||||
|         "num_dof": 5, |         "num_dof": 5, | ||||||
| @ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): | |||||||
|     } |     } | ||||||
|     env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) |     env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) | ||||||
|     # OR for a deterministic ProMP: |     # OR for a deterministic ProMP: | ||||||
|     # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) |     # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs) | ||||||
| 
 | 
 | ||||||
|     if render: |     if render: | ||||||
|         env.render(mode="human") |         env.render(mode="human") | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ import alr_envs | |||||||
| def example_mp(env_name, seed=1): | def example_mp(env_name, seed=1): | ||||||
|     """ |     """ | ||||||
|     Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. |     Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. | ||||||
|     For more information on motion primitive specific stuff, look at the mp examples. |     For more information on motion primitive specific stuff, look at the trajectory_generator examples. | ||||||
|     Args: |     Args: | ||||||
|         env_name: ProMP env_id |         env_name: ProMP env_id | ||||||
|         seed: seed |         seed: seed | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env | |||||||
| 
 | 
 | ||||||
| def visualize(env): | def visualize(env): | ||||||
|     t = env.t |     t = env.t | ||||||
|     pos_features = env.mp.basis_generator.basis(t) |     pos_features = env.trajectory_generator.basis_generator.basis(t) | ||||||
|     plt.plot(t, pos_features) |     plt.plot(t, pos_features) | ||||||
|     plt.show() |     plt.show() | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ for _task in _goal_change_envs: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": _task, |             "name": _task, | ||||||
|             "wrappers": [goal_change_mp_wrapper.MPWrapper], |             "wrappers": [goal_change_mp_wrapper.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 4, |                 "num_dof": 4, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 6.25, |                 "duration": 6.25, | ||||||
| @ -42,7 +42,7 @@ for _task in _object_change_envs: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": _task, |             "name": _task, | ||||||
|             "wrappers": [object_change_mp_wrapper.MPWrapper], |             "wrappers": [object_change_mp_wrapper.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 4, |                 "num_dof": 4, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 6.25, |                 "duration": 6.25, | ||||||
| @ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": _task, |             "name": _task, | ||||||
|             "wrappers": [goal_object_change_mp_wrapper.MPWrapper], |             "wrappers": [goal_object_change_mp_wrapper.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 4, |                 "num_dof": 4, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 6.25, |                 "duration": 6.25, | ||||||
| @ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs: | |||||||
|         kwargs={ |         kwargs={ | ||||||
|             "name": _task, |             "name": _task, | ||||||
|             "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper], |             "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper], | ||||||
|             "mp_kwargs": { |             "traj_gen_kwargs": { | ||||||
|                 "num_dof": 4, |                 "num_dof": 4, | ||||||
|                 "num_basis": 5, |                 "num_basis": 5, | ||||||
|                 "duration": 6.25, |                 "duration": 6.25, | ||||||
|  | |||||||
| @ -1,5 +1,5 @@ | |||||||
| from abc import ABC, abstractmethod | from abc import ABC | ||||||
| from typing import Union, Tuple | from typing import Tuple | ||||||
| 
 | 
 | ||||||
| import gym | import gym | ||||||
| import numpy as np | import numpy as np | ||||||
| @ -7,77 +7,77 @@ from gym import spaces | |||||||
| from mp_pytorch.mp.mp_interfaces import MPInterface | from mp_pytorch.mp.mp_interfaces import MPInterface | ||||||
| 
 | 
 | ||||||
| from alr_envs.mp.controllers.base_controller import BaseController | from alr_envs.mp.controllers.base_controller import BaseController | ||||||
|  | from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): | class BlackBoxWrapper(gym.ObservationWrapper, ABC): | ||||||
|     """ |  | ||||||
|     Base class for movement primitive based gym.Wrapper implementations. |  | ||||||
| 
 | 
 | ||||||
|     Args: |     def __init__(self, | ||||||
|         env: The (wrapped) environment this wrapper is applied on |                  env: RawInterfaceWrapper, | ||||||
|         num_dof: Dimension of the action space of the wrapped env |                  trajectory_generator: MPInterface, tracking_controller: BaseController, | ||||||
|         num_basis: Number of basis functions per dof |                  duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum): | ||||||
|         duration: Length of the trajectory of the movement primitive in seconds |         """ | ||||||
|         controller: Type or object defining the policy that is used to generate action based on the trajectory |         gym.Wrapper for leveraging a black box approach with a trajectory generator. | ||||||
|         weight_scale: Scaling parameter for the actions given to this wrapper |  | ||||||
|         render_mode: Equivalent to gym render mode |  | ||||||
|     """ |  | ||||||
| 
 | 
 | ||||||
|     def __init__( |         Args: | ||||||
|             self, |             env: The (wrapped) environment this wrapper is applied on | ||||||
|             env: gym.Env, |             trajectory_generator: Generates the full or partial trajectory | ||||||
|             mp: MPInterface, |             tracking_controller: Translates the desired trajectory to raw action sequences | ||||||
|             controller: BaseController, |             duration: Length of the trajectory of the movement primitive in seconds | ||||||
|             duration: float, |             verbose: level of detail for returned values in info dict. | ||||||
|             render_mode: str = None, |             reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory | ||||||
|             verbose: int = 1, |                 reward, default summation over all values. | ||||||
|             weight_scale: float = 1, |         """ | ||||||
|             sequencing=True, |  | ||||||
|             reward_aggregation=np.mean, |  | ||||||
|             ): |  | ||||||
|         super().__init__() |         super().__init__() | ||||||
| 
 | 
 | ||||||
|         self.env = env |         self.env = env | ||||||
|         try: |  | ||||||
|             self.dt = env.dt |  | ||||||
|         except AttributeError: |  | ||||||
|             raise AttributeError("step based environment needs to have a function 'dt' ") |  | ||||||
|         self.duration = duration |         self.duration = duration | ||||||
|         self.traj_steps = int(duration / self.dt) |         self.traj_steps = int(duration / self.dt) | ||||||
|         self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps |         self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps | ||||||
|         # duration = self.env.max_episode_steps * self.dt |         # duration = self.env.max_episode_steps * self.dt | ||||||
| 
 | 
 | ||||||
|         self.mp = mp |         # trajectory generation | ||||||
|         self.env = env |         self.trajectory_generator = trajectory_generator | ||||||
|         self.controller = controller |         self.tracking_controller = tracking_controller | ||||||
|         self.weight_scale = weight_scale |         # self.weight_scale = weight_scale | ||||||
| 
 |  | ||||||
|         # rendering |  | ||||||
|         self.render_mode = render_mode |  | ||||||
|         self.render_kwargs = {} |  | ||||||
|         self.time_steps = np.linspace(0, self.duration, self.traj_steps) |         self.time_steps = np.linspace(0, self.duration, self.traj_steps) | ||||||
|         self.mp.set_mp_times(self.time_steps) |         self.trajectory_generator.set_mp_times(self.time_steps) | ||||||
|         # self.mp.set_mp_duration(self.time_steps, dt) |         # self.trajectory_generator.set_mp_duration(self.time_steps, dt) | ||||||
|         # action_bounds = np.inf * np.ones((np.prod(self.mp.num_params))) |         # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params))) | ||||||
|         self.mp_action_space = self.get_mp_action_space() |         self.reward_aggregation = reward_aggregation | ||||||
| 
 | 
 | ||||||
|  |         # spaces | ||||||
|  |         self.mp_action_space = self.get_mp_action_space() | ||||||
|         self.action_space = self.get_action_space() |         self.action_space = self.get_action_space() | ||||||
|         self.active_obs = self.set_active_obs() |         self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask], | ||||||
|         self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs], |                                             high=self.env.observation_space.high[self.env.context_mask], | ||||||
|                                             high=self.env.observation_space.high[self.active_obs], |  | ||||||
|                                             dtype=self.env.observation_space.dtype) |                                             dtype=self.env.observation_space.dtype) | ||||||
| 
 | 
 | ||||||
|  |         # rendering | ||||||
|  |         self.render_mode = None | ||||||
|  |         self.render_kwargs = {} | ||||||
|  | 
 | ||||||
|         self.verbose = verbose |         self.verbose = verbose | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     def dt(self): | ||||||
|  |         return self.env.dt | ||||||
|  | 
 | ||||||
|  |     def observation(self, observation): | ||||||
|  |         return observation[self.env.context_mask] | ||||||
|  | 
 | ||||||
|     def get_trajectory(self, action: np.ndarray) -> Tuple: |     def get_trajectory(self, action: np.ndarray) -> Tuple: | ||||||
|         # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at |         # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at | ||||||
|         #  the beginning of the array. |         #  the beginning of the array. | ||||||
|         ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay) |         # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay) | ||||||
|         scaled_mp_params = action.copy() |         # scaled_mp_params = action.copy() | ||||||
|         scaled_mp_params[ignore_indices:] *= self.weight_scale |         # scaled_mp_params[ignore_indices:] *= self.weight_scale | ||||||
|         self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high)) | 
 | ||||||
|         self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel) |         clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high) | ||||||
|         traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True) |         self.trajectory_generator.set_params(clipped_params) | ||||||
|  |         self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, | ||||||
|  |                                                           bc_vel=self.current_vel) | ||||||
|  |         traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True) | ||||||
|         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] |         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] | ||||||
| 
 | 
 | ||||||
|         trajectory = trajectory_tensor.numpy() |         trajectory = trajectory_tensor.numpy() | ||||||
| @ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): | |||||||
|         # TODO: Do we need this or does mp_pytorch have this? |         # TODO: Do we need this or does mp_pytorch have this? | ||||||
|         if self.post_traj_steps > 0: |         if self.post_traj_steps > 0: | ||||||
|             trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])]) |             trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])]) | ||||||
|             velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))]) |             velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))]) | ||||||
| 
 | 
 | ||||||
|         return trajectory, velocity |         return trajectory, velocity | ||||||
| 
 | 
 | ||||||
|     def get_mp_action_space(self): |     def get_mp_action_space(self): | ||||||
|         """This function can be used to set up an individual space for the parameters of the mp.""" |         """This function can be used to set up an individual space for the parameters of the trajectory_generator.""" | ||||||
|         min_action_bounds, max_action_bounds = self.mp.get_param_bounds() |         min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds() | ||||||
|         mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), |         mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), | ||||||
|                                          dtype=np.float32) |                                          dtype=np.float32) | ||||||
|         return mp_action_space |         return mp_action_space | ||||||
| @ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): | |||||||
|         except AttributeError: |         except AttributeError: | ||||||
|             return self.get_mp_action_space() |             return self.get_mp_action_space() | ||||||
| 
 | 
 | ||||||
|     def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: |  | ||||||
|         """ |  | ||||||
|         Used to extract the parameters for the motion primitive and other parameters from an action array which might |  | ||||||
|         include other actions like ball releasing time for the beer pong environment. |  | ||||||
|         This only needs to be overwritten if the action space is modified. |  | ||||||
|         Args: |  | ||||||
|             action: a vector instance of the whole action space, includes mp parameters and additional parameters if |  | ||||||
|             specified, else only mp parameters |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             Tuple: mp_arguments and other arguments |  | ||||||
|         """ |  | ||||||
|         return action, None |  | ||||||
| 
 |  | ||||||
|     def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[ |  | ||||||
|         np.ndarray]: |  | ||||||
|         """ |  | ||||||
|         This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the |  | ||||||
|         Beerpong env. The parameters used should not be part of the motion primitive parameters. |  | ||||||
|         Returns step_action by default, can be overwritten in individual mp_wrappers. |  | ||||||
|         Args: |  | ||||||
|             t: the current time step of the episode |  | ||||||
|             env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback |  | ||||||
|             (e.g. ball release time in Beer Pong) |  | ||||||
|             step_action: the current step-based action |  | ||||||
| 
 |  | ||||||
|         Returns: |  | ||||||
|             modified step action |  | ||||||
|         """ |  | ||||||
|         return step_action |  | ||||||
| 
 |  | ||||||
|     @abstractmethod |  | ||||||
|     def set_active_obs(self) -> np.ndarray: |  | ||||||
|         """ |  | ||||||
|         This function defines the contexts. The contexts are defined as specific observations. |  | ||||||
|         Returns: |  | ||||||
|             boolearn array representing the indices of the observations |  | ||||||
| 
 |  | ||||||
|         """ |  | ||||||
|         return np.ones(self.env.observation_space.shape[0], dtype=bool) |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     @abstractmethod |  | ||||||
|     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: |  | ||||||
|         """ |  | ||||||
|             Returns the current position of the action/control dimension. |  | ||||||
|             The dimensionality has to match the action/control dimension. |  | ||||||
|             This is not required when exclusively using velocity control, |  | ||||||
|             it should, however, be implemented regardless. |  | ||||||
|             E.g. The joint positions that are directly or indirectly controlled by the action. |  | ||||||
|         """ |  | ||||||
|         raise NotImplementedError() |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     @abstractmethod |  | ||||||
|     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: |  | ||||||
|         """ |  | ||||||
|             Returns the current velocity of the action/control dimension. |  | ||||||
|             The dimensionality has to match the action/control dimension. |  | ||||||
|             This is not required when exclusively using position control, |  | ||||||
|             it should, however, be implemented regardless. |  | ||||||
|             E.g. The joint velocities that are directly or indirectly controlled by the action. |  | ||||||
|         """ |  | ||||||
|         raise NotImplementedError() |  | ||||||
| 
 |  | ||||||
|     def step(self, action: np.ndarray): |     def step(self, action: np.ndarray): | ||||||
|         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" |         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" | ||||||
|         # TODO: Think about sequencing |         # TODO: Think about sequencing | ||||||
| @ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): | |||||||
| 
 | 
 | ||||||
|         # TODO |         # TODO | ||||||
|         # self.time_steps = np.linspace(0, learned_duration, self.traj_steps) |         # self.time_steps = np.linspace(0, learned_duration, self.traj_steps) | ||||||
|         # self.mp.set_mp_times(self.time_steps) |         # self.trajectory_generator.set_mp_times(self.time_steps) | ||||||
| 
 | 
 | ||||||
|         trajectory_length = len(trajectory) |         trajectory_length = len(trajectory) | ||||||
|  |         rewards = np.zeros(shape=(trajectory_length,)) | ||||||
|         if self.verbose >= 2: |         if self.verbose >= 2: | ||||||
|             actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape) |             actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape) | ||||||
|             observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape, |             observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape, | ||||||
|                                     dtype=self.env.observation_space.dtype) |                                     dtype=self.env.observation_space.dtype) | ||||||
|             rewards = np.zeros(shape=(trajectory_length,)) |  | ||||||
|         trajectory_return = 0 |  | ||||||
| 
 | 
 | ||||||
|         infos = dict() |         infos = dict() | ||||||
|  |         done = False | ||||||
| 
 | 
 | ||||||
|         for t, pos_vel in enumerate(zip(trajectory, velocity)): |         for t, pos_vel in enumerate(zip(trajectory, velocity)): | ||||||
|             step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel) |             step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, | ||||||
|  |                                                               self.current_vel) | ||||||
|             step_action = self._step_callback(t, env_spec_params, step_action)  # include possible callback info |             step_action = self._step_callback(t, env_spec_params, step_action)  # include possible callback info | ||||||
|             c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) |             c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) | ||||||
|             # print('step/clipped action ratio: ', step_action/c_action) |             # print('step/clipped action ratio: ', step_action/c_action) | ||||||
|             obs, c_reward, done, info = self.env.step(c_action) |             obs, c_reward, done, info = self.env.step(c_action) | ||||||
|  |             rewards[t] = c_reward | ||||||
|  | 
 | ||||||
|             if self.verbose >= 2: |             if self.verbose >= 2: | ||||||
|                 actions[t, :] = c_action |                 actions[t, :] = c_action | ||||||
|                 rewards[t] = c_reward |  | ||||||
|                 observations[t, :] = obs |                 observations[t, :] = obs | ||||||
|             trajectory_return += c_reward | 
 | ||||||
|             for k, v in info.items(): |             for k, v in info.items(): | ||||||
|                 elems = infos.get(k, [None] * trajectory_length) |                 elems = infos.get(k, [None] * trajectory_length) | ||||||
|                 elems[t] = v |                 elems[t] = v | ||||||
|                 infos[k] = elems |                 infos[k] = elems | ||||||
|             # infos['step_infos'].append(info) | 
 | ||||||
|             if self.render_mode: |             if self.render_mode is not None: | ||||||
|                 self.render(mode=self.render_mode, **self.render_kwargs) |                 self.render(mode=self.render_mode, **self.render_kwargs) | ||||||
|             if done or do_replanning(kwargs): | 
 | ||||||
|  |             if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t): | ||||||
|                 break |                 break | ||||||
|  | 
 | ||||||
|         infos.update({k: v[:t + 1] for k, v in infos.items()}) |         infos.update({k: v[:t + 1] for k, v in infos.items()}) | ||||||
|  | 
 | ||||||
|         if self.verbose >= 2: |         if self.verbose >= 2: | ||||||
|             infos['trajectory'] = trajectory |             infos['trajectory'] = trajectory | ||||||
|             infos['step_actions'] = actions[:t + 1] |             infos['step_actions'] = actions[:t + 1] | ||||||
|             infos['step_observations'] = observations[:t + 1] |             infos['step_observations'] = observations[:t + 1] | ||||||
|             infos['step_rewards'] = rewards[:t + 1] |             infos['step_rewards'] = rewards[:t + 1] | ||||||
|  | 
 | ||||||
|         infos['trajectory_length'] = t + 1 |         infos['trajectory_length'] = t + 1 | ||||||
|         done = True |         trajectory_return = self.reward_aggregation(rewards[:t + 1]) | ||||||
|         return self.get_observation_from_step(obs), trajectory_return, done, infos |         return self.get_observation_from_step(obs), trajectory_return, done, infos | ||||||
| 
 | 
 | ||||||
|     def reset(self): |     def reset(self): | ||||||
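To tie the new pieces together, a hedged wiring sketch based on the constructor signature above; how the raw env, trajectory generator, and tracking controller are actually built lives outside this diff.

    import numpy as np
    from alr_envs.mp.black_box_wrapper import BlackBoxWrapper

    def make_black_box_env(raw_env, traj_gen, controller, duration=2.0):
        # raw_env: RawInterfaceWrapper, traj_gen: MPInterface, controller: BaseController
        return BlackBoxWrapper(raw_env,
                               trajectory_generator=traj_gen,
                               tracking_controller=controller,
                               duration=duration,
                               verbose=1,
                               reward_aggregation=np.sum)

    # env.step() then consumes one MP parameter vector per episode and returns the
    # context observation plus the aggregated trajectory reward (see step() above).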
| @ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController | |||||||
| class MetaWorldController(BaseController): | class MetaWorldController(BaseController): | ||||||
|     """ |     """ | ||||||
|     A Metaworld Controller. Using position and velocity information from a provided environment, |     A Metaworld Controller. Using position and velocity information from a provided environment, | ||||||
|     the controller calculates a response based on the desired position and velocity. |     the tracking_controller calculates a response based on the desired position and velocity. | ||||||
|     Unlike the other Controllers, this is a special controller for MetaWorld environments. |     Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments. | ||||||
|     They use a position delta for the xyz coordinates and a raw position for the gripper opening. |     They use a position delta for the xyz coordinates and a raw position for the gripper opening. | ||||||
| 
 | 
 | ||||||
|     :param env: A position environment |     :param env: A position environment | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController | |||||||
| class PDController(BaseController): | class PDController(BaseController): | ||||||
|     """ |     """ | ||||||
|     A PD-Controller. Using position and velocity information from a provided environment, |     A PD-Controller. Using position and velocity information from a provided environment, | ||||||
|     the controller calculates a response based on the desired position and velocity |     the tracking_controller calculates a response based on the desired position and velocity | ||||||
| 
 | 
 | ||||||
|     :param env: A position environment |     :param env: A position environment | ||||||
|     :param p_gains: Factors for the proportional gains |     :param p_gains: Factors for the proportional gains | ||||||
|  | |||||||
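As a hedged illustration of the control law the docstring describes (the real get_action body is outside this hunk, and the gain values are made up):

    import numpy as np

    def pd_action(des_pos, des_vel, cur_pos, cur_vel, p_gains=50.0, d_gains=1.0):
        # torque = P * position error + D * velocity error
        return p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel)

    trq = pd_action(np.zeros(7), np.zeros(7), 0.1 * np.ones(7), np.zeros(7))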
| @ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController | |||||||
| 
 | 
 | ||||||
| class PosController(BaseController): | class PosController(BaseController): | ||||||
|     """ |     """ | ||||||
|     A Position Controller. The controller calculates a response only based on the desired position. |     A Position Controller. The tracking_controller calculates a response only based on the desired position. | ||||||
|     """ |     """ | ||||||
|     def get_action(self, des_pos, des_vel, c_pos, c_vel): |     def get_action(self, des_pos, des_vel, c_pos, c_vel): | ||||||
|         return des_pos |         return des_pos | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController | |||||||
| 
 | 
 | ||||||
| class VelController(BaseController): | class VelController(BaseController): | ||||||
|     """ |     """ | ||||||
|     A Velocity Controller. The controller calculates a response only based on the desired velocity. |     A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity. | ||||||
|     """ |     """ | ||||||
|     def get_action(self, des_pos, des_vel, c_pos, c_vel): |     def get_action(self, des_pos, des_vel, c_pos, c_vel): | ||||||
|         return des_vel |         return des_vel | ||||||
|  | |||||||
| @ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator | |||||||
| ALL_TYPES = ["promp", "dmp", "idmp"] | ALL_TYPES = ["promp", "dmp", "idmp"] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_movement_primitive( | def get_trajectory_generator( | ||||||
|         movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs |         trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs | ||||||
|         ): |         ): | ||||||
|     movement_primitives_type = movement_primitives_type.lower() |     trajectory_generator_type = trajectory_generator_type.lower() | ||||||
|     if movement_primitives_type == "promp": |     if trajectory_generator_type == "promp": | ||||||
|         return ProMP(basis_generator, action_dim, **kwargs) |         return ProMP(basis_generator, action_dim, **kwargs) | ||||||
|     elif movement_primitives_type == "dmp": |     elif trajectory_generator_type == "dmp": | ||||||
|         return DMP(basis_generator, action_dim, **kwargs) |         return DMP(basis_generator, action_dim, **kwargs) | ||||||
|     elif movement_primitives_type == 'idmp': |     elif trajectory_generator_type == 'idmp': | ||||||
|         return IDMP(basis_generator, action_dim, **kwargs) |         return IDMP(basis_generator, action_dim, **kwargs) | ||||||
|     else: |     else: | ||||||
|         raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, " |         raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, " | ||||||
|                          f"please choose one of {ALL_TYPES}.") |                          f"please choose one of {ALL_TYPES}.") | ||||||
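A hedged usage sketch of the renamed factory; its module path is not visible in this diff, so the import is only noted in a comment rather than invented here.

    # assuming get_trajectory_generator has been imported from the factory module above
    def build_promp(basis_gen, action_dim=7):
        # "promp" selects the ProMP branch; any extra **kwargs go to the MP class
        return get_trajectory_generator("promp", action_dim=action_dim,
                                        basis_generator=basis_gen)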
alr_envs/mp/raw_interface_wrapper.py (new file, 88 lines)
| @ -0,0 +1,88 @@ | |||||||
|  | from typing import Union, Tuple | ||||||
|  | 
 | ||||||
|  | import gym | ||||||
|  | import numpy as np | ||||||
|  | from abc import abstractmethod | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class RawInterfaceWrapper(gym.Wrapper): | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     @abstractmethod | ||||||
|  |     def context_mask(self) -> np.ndarray: | ||||||
|  |         """ | ||||||
|  |         This function defines the contexts. The contexts are defined as specific observations. | ||||||
|  |         Returns: | ||||||
|  |             bool array representing the indices of the observations | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         return np.ones(self.env.observation_space.shape[0], dtype=bool) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     @abstractmethod | ||||||
|  |     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|  |         """ | ||||||
|  |             Returns the current position of the action/control dimension. | ||||||
|  |             The dimensionality has to match the action/control dimension. | ||||||
|  |             This is not required when exclusively using velocity control, | ||||||
|  |             it should, however, be implemented regardless. | ||||||
|  |             E.g. The joint positions that are directly or indirectly controlled by the action. | ||||||
|  |         """ | ||||||
|  |         raise NotImplementedError() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     @abstractmethod | ||||||
|  |     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: | ||||||
|  |         """ | ||||||
|  |             Returns the current velocity of the action/control dimension. | ||||||
|  |             The dimensionality has to match the action/control dimension. | ||||||
|  |             This is not required when exclusively using position control, | ||||||
|  |             it should, however, be implemented regardless. | ||||||
|  |             E.g. The joint velocities that are directly or indirectly controlled by the action. | ||||||
|  |         """ | ||||||
|  |         raise NotImplementedError() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     @abstractmethod | ||||||
|  |     def dt(self) -> float: | ||||||
|  |         """ | ||||||
|  |         Control frequency of the environment | ||||||
|  |         Returns: float | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |     def do_replanning(self, pos, vel, s, a, t): | ||||||
|  |         # return t % 100 == 0 | ||||||
|  |         # return bool(self.replanning_model(s)) | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  |     def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: | ||||||
|  |         """ | ||||||
|  |         Used to extract the parameters for the motion primitive and other parameters from an action array which might | ||||||
|  |         include other actions like ball releasing time for the beer pong environment. | ||||||
|  |         This only needs to be overwritten if the action space is modified. | ||||||
|  |         Args: | ||||||
|  |             action: a vector covering the full action space; it contains the trajectory_generator parameters | ||||||
|  |             and, if specified, additional environment-specific parameters | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             Tuple: mp_arguments and other arguments | ||||||
|  |         """ | ||||||
|  |         return action, None | ||||||
|  | 
 | ||||||
|  |     def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> np.ndarray: | ||||||
|  |         """ | ||||||
|  |         This function can be used to modify the step_action with additional parameters, e.g. releasing the ball in | ||||||
|  |         the Beer Pong env. These parameters should not be part of the motion primitive parameters. | ||||||
|  |         Returns step_action unchanged by default; can be overridden in individual mp_wrappers. | ||||||
|  |         Args: | ||||||
|  |             t: the current time step of the episode | ||||||
|  |             env_spec_params: the environment specific parameter, as defined in function _episode_callback | ||||||
|  |             (e.g. ball release time in Beer Pong) | ||||||
|  |             step_action: the current step-based action | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             modified step action | ||||||
|  |         """ | ||||||
|  |         return step_action | ||||||
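The interface above is what an environment has to expose in order to be usable with the motion primitive machinery. Below is a minimal sketch of a concrete wrapper; the observation layout and the `current_pos` / `current_vel` / `dt` attributes of the wrapped env are assumptions for illustration and not part of this commit.

```python
# Hypothetical example (not from this commit): a RawInterfaceWrapper for an env whose
# first two observation entries encode the goal. The attribute names on the wrapped env
# (`current_pos`, `current_vel`, `dt`) are assumptions for illustration.
from typing import Tuple, Union

import numpy as np

from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper


class MyMPWrapper(RawInterfaceWrapper):

    @property
    def context_mask(self) -> np.ndarray:
        # expose only the first two observation entries (the goal) as context
        mask = np.zeros(self.env.observation_space.shape[0], dtype=bool)
        mask[:2] = True
        return mask

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        # joint positions controlled by the action, taken from the wrapped env
        return self.env.current_pos

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        # joint velocities controlled by the action, taken from the wrapped env
        return self.env.current_vel

    @property
    def dt(self) -> float:
        # control timestep of the wrapped env in seconds
        return self.env.dt
```

A wrapper of this kind is what the registrations below pass in their `wrappers` list.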
| @ -21,7 +21,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "alr_envs:MountainCarContinuous-v1", |         "name": "alr_envs:MountainCarContinuous-v1", | ||||||
|         "wrappers": [classic_control.continuous_mountain_car.MPWrapper], |         "wrappers": [classic_control.continuous_mountain_car.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 4, |             "num_basis": 4, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
| @ -43,7 +43,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.classic_control:MountainCarContinuous-v0", |         "name": "gym.envs.classic_control:MountainCarContinuous-v0", | ||||||
|         "wrappers": [classic_control.continuous_mountain_car.MPWrapper], |         "wrappers": [classic_control.continuous_mountain_car.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 1, |             "num_dof": 1, | ||||||
|             "num_basis": 4, |             "num_basis": 4, | ||||||
|             "duration": 19.98, |             "duration": 19.98, | ||||||
| @ -65,7 +65,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.mujoco:Reacher-v2", |         "name": "gym.envs.mujoco:Reacher-v2", | ||||||
|         "wrappers": [mujoco.reacher_v2.MPWrapper], |         "wrappers": [mujoco.reacher_v2.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 2, |             "num_dof": 2, | ||||||
|             "num_basis": 6, |             "num_basis": 6, | ||||||
|             "duration": 1, |             "duration": 1, | ||||||
| @ -87,7 +87,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.robotics:FetchSlideDense-v1", |         "name": "gym.envs.robotics:FetchSlideDense-v1", | ||||||
|         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], |         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 4, |             "num_dof": 4, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
| @ -105,7 +105,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.robotics:FetchSlide-v1", |         "name": "gym.envs.robotics:FetchSlide-v1", | ||||||
|         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], |         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 4, |             "num_dof": 4, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
| @ -123,7 +123,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.robotics:FetchReachDense-v1", |         "name": "gym.envs.robotics:FetchReachDense-v1", | ||||||
|         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], |         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 4, |             "num_dof": 4, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
| @ -141,7 +141,7 @@ register( | |||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "gym.envs.robotics:FetchReach-v1", |         "name": "gym.envs.robotics:FetchReach-v1", | ||||||
|         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], |         "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], | ||||||
|         "mp_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "num_dof": 4, |             "num_dof": 4, | ||||||
|             "num_basis": 5, |             "num_basis": 5, | ||||||
|             "duration": 2, |             "duration": 2, | ||||||
|  | |||||||
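The surrounding `register(...)` calls are unchanged apart from this key rename; only their `kwargs` blocks appear in the hunks above. A hedged sketch of a full registration after the rename follows; the `id`, `entry_point`, and import path are assumptions for illustration, not taken from this commit.

```python
from gym.envs.registration import register

# assumed import location of the wrapper module used below
from alr_envs import classic_control

register(
    id="MountainCarContinuousProMP-v1",  # hypothetical id
    # assumed entry point; the helper itself is introduced further down in this commit
    entry_point="alr_envs.utils.make_env_helpers:make_bb_env_helper",
    kwargs={
        "name": "alr_envs:MountainCarContinuous-v1",
        "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
        "traj_gen_kwargs": {  # formerly "mp_kwargs"
            "num_dof": 1,
            "num_basis": 4,
            "duration": 2,
        },
    },
)
```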
| @ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping | |||||||
| import gym | import gym | ||||||
| import numpy as np | import numpy as np | ||||||
| from gym.envs.registration import EnvSpec | from gym.envs.registration import EnvSpec | ||||||
| 
 |  | ||||||
| from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper |  | ||||||
| from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper |  | ||||||
| from mp_pytorch import MPInterface | from mp_pytorch import MPInterface | ||||||
| 
 | 
 | ||||||
| from alr_envs.mp.basis_generator_factory import get_basis_generator | from alr_envs.mp.basis_generator_factory import get_basis_generator | ||||||
|  | from alr_envs.mp.black_box_wrapper import BlackBoxWrapper | ||||||
| from alr_envs.mp.controllers.base_controller import BaseController | from alr_envs.mp.controllers.base_controller import BaseController | ||||||
| from alr_envs.mp.controllers.controller_factory import get_controller | from alr_envs.mp.controllers.controller_factory import get_controller | ||||||
| from alr_envs.mp.mp_factory import get_movement_primitive | from alr_envs.mp.mp_factory import get_trajectory_generator | ||||||
| from alr_envs.mp.episodic_wrapper import EpisodicWrapper |  | ||||||
| from alr_envs.mp.phase_generator_factory import get_phase_generator | from alr_envs.mp.phase_generator_factory import get_phase_generator | ||||||
|  | from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): | def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): | ||||||
| @ -100,9 +98,8 @@ def make(env_id: str, seed, **kwargs): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _make_wrapped_env( | def _make_wrapped_env( | ||||||
|         env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController, |         env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs | ||||||
|         ep_wrapper_kwargs: Mapping, seed=1, **kwargs | ): | ||||||
|         ): |  | ||||||
|     """ |     """ | ||||||
|     Helper function for creating a wrapped gym environment using MPs. |     Helper function for creating a wrapped gym environment using MPs. | ||||||
|     It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is |     It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is | ||||||
| @ -118,73 +115,74 @@ def _make_wrapped_env( | |||||||
|     """ |     """ | ||||||
|     # _env = gym.make(env_id) |     # _env = gym.make(env_id) | ||||||
|     _env = make(env_id, seed, **kwargs) |     _env = make(env_id, seed, **kwargs) | ||||||
|     has_episodic_wrapper = False |     has_black_box_wrapper = False | ||||||
|     for w in wrappers: |     for w in wrappers: | ||||||
|         # only wrap the environment if not EpisodicWrapper, e.g. for vision |         # check whether a RawInterfaceWrapper is among the applied wrappers | ||||||
|         if not issubclass(w, EpisodicWrapper): |         if issubclass(w, RawInterfaceWrapper): | ||||||
|             _env = w(_env) |             has_black_box_wrapper = True | ||||||
|         else:  # if EpisodicWrapper, use specific constructor |         _env = w(_env) | ||||||
|             has_episodic_wrapper = True |     if not has_black_box_wrapper: | ||||||
|             _env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs) |         raise ValueError("A RawInterfaceWrapper is required in order to leverage movement primitive environments.") | ||||||
|     if not has_episodic_wrapper: |  | ||||||
|         raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.") |  | ||||||
|     return _env |     return _env | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def make_mp_from_kwargs( | def make_bb_env( | ||||||
|         env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping, |         env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, | ||||||
|         controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1, |         controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1, | ||||||
|         sequenced=False, **kwargs |         sequenced=False, **kwargs): | ||||||
|         ): |  | ||||||
|     """ |     """ | ||||||
|     This can also be used standalone for manually building a custom DMP environment. |     This can also be used standalone for manually building a custom DMP environment. | ||||||
|     Args: |     Args: | ||||||
|         ep_wrapper_kwargs: |         black_box_wrapper_kwargs: kwargs for the black-box wrapper | ||||||
|         basis_kwargs: |         basis_kwargs: kwargs for the basis generator | ||||||
|         phase_kwargs: |         phase_kwargs: kwargs for the phase generator | ||||||
|         controller_kwargs: |         controller_kwargs: kwargs for the tracking controller | ||||||
|         env_id: base_env_name, |         env_id: base_env_name, | ||||||
|         wrappers: list of wrappers (at least an EpisodicWrapper), |         wrappers: list of wrappers (at least one RawInterfaceWrapper), | ||||||
|         seed: seed of environment |         seed: seed of environment | ||||||
|         sequenced: When true, multiple ProMPs can be sequenced by specifying the duration of each sub-trajectory; |         sequenced: When true, multiple ProMPs can be sequenced by specifying the duration of each sub-trajectory; | ||||||
|                 this behavior is much closer to step-based learning. |                 this behavior is much closer to step-based learning. | ||||||
|         mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP |         traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for the trajectory generator | ||||||
| 
 | 
 | ||||||
|     Returns: DMP wrapped gym env |     Returns: DMP wrapped gym env | ||||||
| 
 | 
 | ||||||
|     """ |     """ | ||||||
|     _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) |     _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None)) | ||||||
|     dummy_env = make(env_id, seed) |     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) | ||||||
|     if ep_wrapper_kwargs.get('duration', None) is None: | 
 | ||||||
|         ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt |     if black_box_wrapper_kwargs.get('duration', None) is None: | ||||||
|  |         black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt | ||||||
|     if phase_kwargs.get('tau', None) is None: |     if phase_kwargs.get('tau', None) is None: | ||||||
|         phase_kwargs['tau'] = ep_wrapper_kwargs['duration'] |         phase_kwargs['tau'] = black_box_wrapper_kwargs['duration'] | ||||||
|     mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item()) |     traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item()) | ||||||
|  | 
 | ||||||
|     phase_gen = get_phase_generator(**phase_kwargs) |     phase_gen = get_phase_generator(**phase_kwargs) | ||||||
|     basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs) |     basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs) | ||||||
|     controller = get_controller(**controller_kwargs) |     controller = get_controller(**controller_kwargs) | ||||||
|     mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs) |     traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs) | ||||||
|     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller, | 
 | ||||||
|                              ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs) |     bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller, | ||||||
|     return _env |                              **black_box_wrapper_kwargs) | ||||||
|  | 
 | ||||||
|  |     return bb_env | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def make_mp_env_helper(**kwargs): | def make_bb_env_helper(**kwargs): | ||||||
|     """ |     """ | ||||||
|     Helper function for registering a DMP gym environments. |     Helper function for registering a black box gym environment. | ||||||
|     Args: |     Args: | ||||||
|         **kwargs: expects at least the following: |         **kwargs: expects at least the following: | ||||||
|         { |         { | ||||||
|         "name": base environment name. |         "name": base environment name. | ||||||
|         "wrappers": list of wrappers (at least an EpisodicWrapper is required), |         "wrappers": list of wrappers (at least one RawInterfaceWrapper is required), | ||||||
|         "movement_primitives_kwargs": { |         "traj_gen_kwargs": { | ||||||
|             "movement_primitives_type": type_of_your_movement_primitive, |             "trajectory_generator_type": type_of_your_movement_primitive, | ||||||
|             non default arguments for the movement primitive instance |             non default arguments for the movement primitive instance | ||||||
|             ... |             ... | ||||||
|             } |             } | ||||||
|         "controller_kwargs": { |         "controller_kwargs": { | ||||||
|             "controller_type": type_of_your_controller, |             "controller_type": type_of_your_controller, | ||||||
|             non default arguments for the controller instance |             non default arguments for the tracking_controller instance | ||||||
|             ... |             ... | ||||||
|             }, |             }, | ||||||
|         "basis_generator_kwargs": { |         "basis_generator_kwargs": { | ||||||
| @ -205,95 +203,17 @@ def make_mp_env_helper(**kwargs): | |||||||
|     seed = kwargs.pop("seed", None) |     seed = kwargs.pop("seed", None) | ||||||
|     wrappers = kwargs.pop("wrappers") |     wrappers = kwargs.pop("wrappers") | ||||||
| 
 | 
 | ||||||
|     mp_kwargs = kwargs.pop("movement_primitives_kwargs") |     traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {}) | ||||||
|     ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs') |     black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {}) | ||||||
|     contr_kwargs = kwargs.pop("controller_kwargs") |     contr_kwargs = kwargs.pop("controller_kwargs", {}) | ||||||
|     phase_kwargs = kwargs.pop("phase_generator_kwargs") |     phase_kwargs = kwargs.pop("phase_generator_kwargs", {}) | ||||||
|     basis_kwargs = kwargs.pop("basis_generator_kwargs") |     basis_kwargs = kwargs.pop("basis_generator_kwargs", {}) | ||||||
| 
 | 
 | ||||||
|     return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs, |     return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers, | ||||||
|                                mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs, |                        black_box_wrapper_kwargs=black_box_kwargs, | ||||||
|                                basis_kwargs=basis_kwargs, **kwargs, seed=seed) |                        traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, | ||||||
| 
 |                        phase_kwargs=phase_kwargs, | ||||||
| 
 |                        basis_kwargs=basis_kwargs, **kwargs, seed=seed) | ||||||
| def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): |  | ||||||
|     """ |  | ||||||
|     This can also be used standalone for manually building a custom DMP environment. |  | ||||||
|     Args: |  | ||||||
|         env_id: base_env_name, |  | ||||||
|         wrappers: list of wrappers (at least an MPEnvWrapper), |  | ||||||
|         seed: seed of environment |  | ||||||
|         mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP |  | ||||||
| 
 |  | ||||||
|     Returns: DMP wrapped gym env |  | ||||||
| 
 |  | ||||||
|     """ |  | ||||||
|     _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) |  | ||||||
| 
 |  | ||||||
|     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) |  | ||||||
| 
 |  | ||||||
|     _verify_dof(_env, mp_kwargs.get("num_dof")) |  | ||||||
| 
 |  | ||||||
|     return DmpWrapper(_env, **mp_kwargs) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): |  | ||||||
|     """ |  | ||||||
|     This can also be used standalone for manually building a custom ProMP environment. |  | ||||||
|     Args: |  | ||||||
|         env_id: base_env_name, |  | ||||||
|         wrappers: list of wrappers (at least an MPEnvWrapper), |  | ||||||
|         mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int} |  | ||||||
| 
 |  | ||||||
|     Returns: ProMP wrapped gym env |  | ||||||
| 
 |  | ||||||
|     """ |  | ||||||
|     _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) |  | ||||||
| 
 |  | ||||||
|     _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) |  | ||||||
| 
 |  | ||||||
|     _verify_dof(_env, mp_kwargs.get("num_dof")) |  | ||||||
| 
 |  | ||||||
|     return ProMPWrapper(_env, **mp_kwargs) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def make_dmp_env_helper(**kwargs): |  | ||||||
|     """ |  | ||||||
|     Helper function for registering a DMP gym environments. |  | ||||||
|     Args: |  | ||||||
|         **kwargs: expects at least the following: |  | ||||||
|         { |  | ||||||
|         "name": base_env_name, |  | ||||||
|         "wrappers": list of wrappers (at least an MPEnvWrapper), |  | ||||||
|         "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     Returns: DMP wrapped gym env |  | ||||||
| 
 |  | ||||||
|     """ |  | ||||||
|     seed = kwargs.pop("seed", None) |  | ||||||
|     return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed, |  | ||||||
|                         mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def make_promp_env_helper(**kwargs): |  | ||||||
|     """ |  | ||||||
|     Helper function for registering ProMP gym environments. |  | ||||||
|     This can also be used standalone for manually building a custom ProMP environment. |  | ||||||
|     Args: |  | ||||||
|         **kwargs: expects at least the following: |  | ||||||
|         { |  | ||||||
|         "name": base_env_name, |  | ||||||
|         "wrappers": list of wrappers (at least an MPEnvWrapper), |  | ||||||
|         "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int} |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     Returns: ProMP wrapped gym env |  | ||||||
| 
 |  | ||||||
|     """ |  | ||||||
|     seed = kwargs.pop("seed", None) |  | ||||||
|     return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed, |  | ||||||
|                           mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): | def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): | ||||||
| @ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[ | |||||||
|     It can be found in the BaseMP class. |     It can be found in the BaseMP class. | ||||||
| 
 | 
 | ||||||
|     Args: |     Args: | ||||||
|         mp_time_limit: max trajectory length of mp in seconds |         mp_time_limit: max trajectory length of trajectory_generator in seconds | ||||||
|         env_time_limit: max trajectory length of DMC environment in seconds |         env_time_limit: max trajectory length of DMC environment in seconds | ||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|  | |||||||
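Taken together, the new factory path is: build the base env, stack the step-based wrappers (including a RawInterfaceWrapper), construct phase generator, basis generator, tracking controller and trajectory generator, and wrap everything in a BlackBoxWrapper. A hedged sketch of calling the new helper directly is shown below; the module paths, type strings and gains are assumptions for illustration, while the nesting of the kwargs mirrors the docstring of make_bb_env_helper above.

```python
from alr_envs import classic_control  # assumed import location of the wrapper module
from alr_envs.utils.make_env_helpers import make_bb_env_helper  # assumed module of the new helper

# Hedged sketch: the type strings and gains below are assumptions; valid values are whatever
# the respective factories (get_phase_generator, get_basis_generator, get_controller,
# get_trajectory_generator) accept.
env = make_bb_env_helper(
    name="alr_envs:MountainCarContinuous-v1",
    wrappers=[classic_control.continuous_mountain_car.MPWrapper],  # must contain a RawInterfaceWrapper subclass
    seed=1,
    traj_gen_kwargs={"trajectory_generator_type": "promp"},      # action_dim defaults to the env's action dim
    controller_kwargs={"controller_type": "motor", "p_gains": 1.0, "d_gains": 0.1},
    phase_generator_kwargs={"phase_generator_type": "linear"},   # tau defaults to the rollout duration
    basis_generator_kwargs={"basis_generator_type": "rbf", "num_basis": 4},
    black_box_wrapper_kwargs={},  # duration defaults to max_episode_steps * dt of the base env
)

obs = env.reset()
# a single step of the black-box env corresponds to rolling out one full trajectory
obs, reward, done, info = env.step(env.action_space.sample())
```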