diff --git a/README.md b/README.md index c08a1d4..607af63 100644 --- a/README.md +++ b/README.md @@ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper] mp_kwargs = {...} kwargs = {...} env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs) -# OR for a deterministic ProMP (other mp_kwargs are required): -# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) +# OR for a deterministic ProMP (other traj_gen_kwargs are required): +# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args) rewards = 0 obs = env.reset() diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index ec539db..d8169a8 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -346,7 +346,7 @@ for _v in _versions: kwargs={ "name": f"alr_envs:{_v}", "wrappers": [classic_control.simple_reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2 if "long" not in _v.lower() else 5, "num_basis": 5, "duration": 2, @@ -386,7 +386,7 @@ register( kwargs={ "name": "alr_envs:ViaPointReacher-v0", "wrappers": [classic_control.viapoint_reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 5, "num_basis": 5, "duration": 2, @@ -424,7 +424,7 @@ for _v in _versions: kwargs={ "name": f"alr_envs:HoleReacher-{_v}", "wrappers": [classic_control.hole_reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 5, "num_basis": 5, "duration": 2, @@ -467,7 +467,7 @@ for _v in _versions: kwargs={ "name": f"alr_envs:{_v}", "wrappers": [mujoco.reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 5 if "long" not in _v.lower() else 7, "num_basis": 2, "duration": 4, diff --git a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py index c33048f..c12aa56 100644 --- a/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/ant_jump/new_mp_wrapper.py @@ -1,12 +1,13 @@ -from alr_envs.mp.episodic_wrapper import EpisodicWrapper +from alr_envs.mp.black_box_wrapper import BlackBoxWrapper from typing import Union, Tuple import numpy as np +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class NewMPWrapper(EpisodicWrapper): +class NewMPWrapper(RawInterfaceWrapper): - def set_active_obs(self): + def get_context_mask(self): return np.hstack([ [False] * 111, # ant has 111 dimensional observation space !! [True] # goal height diff --git a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py index 53d7a1a..0df1a7c 100644 --- a/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/beerpong/new_mp_wrapper.py @@ -1,15 +1,11 @@ -from typing import Tuple, Union +from typing import Union, Tuple import numpy as np -from alr_envs.mp.episodic_wrapper import EpisodicWrapper +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class NewMPWrapper(EpisodicWrapper): - - # def __init__(self, replanning_model): - # self.replanning_model = replanning_model - +class NewMPWrapper(RawInterfaceWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qpos[0:7].copy() @@ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper): def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qvel[0:7].copy() - def set_active_obs(self): + def get_context_mask(self): return np.hstack([ [False] * 7, # cos [False] * 7, # sin @@ -27,12 +23,7 @@ class NewMPWrapper(EpisodicWrapper): [False] * 3, # cup_goal_diff_top [True] * 2, # xy position of cup [False] # env steps - ]) - - def do_replanning(self, pos, vel, s, a, t, last_replan_step): - return False - # const = np.arange(0, 1000, 10) - # return bool(self.replanning_model(s)) + ]) def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: if self.mp.learn_tau: diff --git a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py b/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py index 77a1bf6..ccd8f76 100644 --- a/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/hopper_jump/new_mp_wrapper.py @@ -1,9 +1,9 @@ -from alr_envs.mp.episodic_wrapper import EpisodicWrapper +from alr_envs.mp.black_box_wrapper import BlackBoxWrapper from typing import Union, Tuple import numpy as np -class NewMPWrapper(EpisodicWrapper): +class NewMPWrapper(BlackBoxWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qpos[3:6].copy() @@ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper): # ]) # Random x goal + random init pos - def set_active_obs(self): + def get_context_mask(self): return np.hstack([ [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position [True] * 3, # set to true if randomize initial pos @@ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper): class NewHighCtxtMPWrapper(NewMPWrapper): - def set_active_obs(self): + def get_context_mask(self): return np.hstack([ [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position [True] * 3, # set to true if randomize initial pos diff --git a/alr_envs/alr/mujoco/reacher/alr_reacher.py b/alr_envs/alr/mujoco/reacher/alr_reacher.py index c12352a..0699c44 100644 --- a/alr_envs/alr/mujoco/reacher/alr_reacher.py +++ b/alr_envs/alr/mujoco/reacher/alr_reacher.py @@ -149,4 +149,4 @@ if __name__ == '__main__': if d: env.reset() - env.close() \ No newline at end of file + env.close() diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py index 02dc1d8..8df365a 100644 --- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py +++ b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py @@ -1,9 +1,9 @@ -from alr_envs.mp.episodic_wrapper import EpisodicWrapper +from alr_envs.mp.black_box_wrapper import BlackBoxWrapper from typing import Union, Tuple import numpy as np -class MPWrapper(EpisodicWrapper): +class MPWrapper(BlackBoxWrapper): @property def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: @@ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper): def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: return self.env.sim.data.qvel.flat[:self.env.n_links] - def set_active_obs(self): + def get_context_mask(self): return np.concatenate([ [False] * self.env.n_links, # cos [False] * self.env.n_links, # sin diff --git a/alr_envs/dmc/__init__.py b/alr_envs/dmc/__init__.py index ac34415..dc3adf0 100644 --- a/alr_envs/dmc/__init__.py +++ b/alr_envs/dmc/__init__.py @@ -15,7 +15,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.ball_in_cup.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -41,7 +41,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.ball_in_cup.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -65,7 +65,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -92,7 +92,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -117,7 +117,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -144,7 +144,7 @@ register( "time_limit": 20, "episode_length": 1000, "wrappers": [suite.reacher.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 5, "duration": 20, @@ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks: "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks: "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -230,7 +230,7 @@ register( "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.TwoPolesMPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -259,7 +259,7 @@ register( "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.TwoPolesMPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -286,7 +286,7 @@ register( "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.ThreePolesMPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -315,7 +315,7 @@ register( "camera_id": 0, "episode_length": 1000, "wrappers": [suite.cartpole.ThreePolesMPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 5, "duration": 10, @@ -342,7 +342,7 @@ register( # "time_limit": 1, "episode_length": 250, "wrappers": [manipulation.reach_site.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 9, "num_basis": 5, "duration": 10, @@ -365,7 +365,7 @@ register( # "time_limit": 1, "episode_length": 250, "wrappers": [manipulation.reach_site.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 9, "num_basis": 5, "duration": 10, diff --git a/alr_envs/examples/examples_dmc.py b/alr_envs/examples/examples_dmc.py index d223d3c..5658b1f 100644 --- a/alr_envs/examples/examples_dmc.py +++ b/alr_envs/examples/examples_dmc.py @@ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): "learn_goal": True, # learn the goal position (recommended) "alpha_phase": 2, "bandwidth_factor": 2, - "policy_type": "motor", # controller type, 'velocity', 'position', and 'motor' (torque control) + "policy_type": "motor", # tracking_controller type, 'velocity', 'position', and 'motor' (torque control) "weights_scale": 1, # scaling of MP weights "goal_scale": 1, # scaling of learned goal position "policy_kwargs": { # only required for torque control/PD-Controller @@ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): # "frame_skip": 1 } env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) - # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples): - # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) + # OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples): + # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args) # This renders the full MP trajectory # It is only required to call render() once in the beginning, which renders every consecutive trajectory. diff --git a/alr_envs/examples/examples_metaworld.py b/alr_envs/examples/examples_metaworld.py index 9ead50c..3e040cc 100644 --- a/alr_envs/examples/examples_metaworld.py +++ b/alr_envs/examples/examples_metaworld.py @@ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): "width": 0.025, # width of the basis functions "zero_start": True, # start from current environment position if True "weights_scale": 1, # scaling of MP weights - "policy_type": "metaworld", # custom controller type for metaworld environments + "policy_type": "metaworld", # custom tracking_controller type for metaworld environments } env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) - # OR for a DMP (other mp_kwargs are required, see dmc_examples): - # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) + # OR for a DMP (other traj_gen_kwargs are required, see dmc_examples): + # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs) # This renders the full MP trajectory # It is only required to call render() once in the beginning, which renders every consecutive trajectory. diff --git a/alr_envs/examples/examples_motion_primitives.py b/alr_envs/examples/examples_motion_primitives.py index 1a679df..b9d355a 100644 --- a/alr_envs/examples/examples_motion_primitives.py +++ b/alr_envs/examples/examples_motion_primitives.py @@ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations= Returns: """ - # Changing the mp_kwargs is possible by providing them to gym. + # Changing the traj_gen_kwargs is possible by providing them to gym. # E.g. here by providing way to many basis functions mp_kwargs = { "num_dof": 5, @@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): } env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) # OR for a deterministic ProMP: - # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) + # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs) if render: env.render(mode="human") diff --git a/alr_envs/examples/examples_open_ai.py b/alr_envs/examples/examples_open_ai.py index dc0c558..631a3a1 100644 --- a/alr_envs/examples/examples_open_ai.py +++ b/alr_envs/examples/examples_open_ai.py @@ -4,7 +4,7 @@ import alr_envs def example_mp(env_name, seed=1): """ Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. - For more information on motion primitive specific stuff, look at the mp examples. + For more information on motion primitive specific stuff, look at the trajectory_generator examples. Args: env_name: ProMP env_id seed: seed diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index 1683c6f..e05bad1 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env def visualize(env): t = env.t - pos_features = env.mp.basis_generator.basis(t) + pos_features = env.trajectory_generator.basis_generator.basis(t) plt.plot(t, pos_features) plt.show() diff --git a/alr_envs/meta/__init__.py b/alr_envs/meta/__init__.py index 5651224..e0d0ea0 100644 --- a/alr_envs/meta/__init__.py +++ b/alr_envs/meta/__init__.py @@ -19,7 +19,7 @@ for _task in _goal_change_envs: kwargs={ "name": _task, "wrappers": [goal_change_mp_wrapper.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 6.25, @@ -42,7 +42,7 @@ for _task in _object_change_envs: kwargs={ "name": _task, "wrappers": [object_change_mp_wrapper.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 6.25, @@ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs: kwargs={ "name": _task, "wrappers": [goal_object_change_mp_wrapper.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 6.25, @@ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs: kwargs={ "name": _task, "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 6.25, diff --git a/alr_envs/mp/episodic_wrapper.py b/alr_envs/mp/black_box_wrapper.py similarity index 51% rename from alr_envs/mp/episodic_wrapper.py rename to alr_envs/mp/black_box_wrapper.py index b2eb391..5ae0ff9 100644 --- a/alr_envs/mp/episodic_wrapper.py +++ b/alr_envs/mp/black_box_wrapper.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod -from typing import Union, Tuple +from abc import ABC +from typing import Tuple import gym import numpy as np @@ -7,77 +7,77 @@ from gym import spaces from mp_pytorch.mp.mp_interfaces import MPInterface from alr_envs.mp.controllers.base_controller import BaseController +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper -class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): - """ - Base class for movement primitive based gym.Wrapper implementations. +class BlackBoxWrapper(gym.ObservationWrapper, ABC): - Args: - env: The (wrapped) environment this wrapper is applied on - num_dof: Dimension of the action space of the wrapped env - num_basis: Number of basis functions per dof - duration: Length of the trajectory of the movement primitive in seconds - controller: Type or object defining the policy that is used to generate action based on the trajectory - weight_scale: Scaling parameter for the actions given to this wrapper - render_mode: Equivalent to gym render mode - """ + def __init__(self, + env: RawInterfaceWrapper, + trajectory_generator: MPInterface, tracking_controller: BaseController, + duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum): + """ + gym.Wrapper for leveraging a black box approach with a trajectory generator. - def __init__( - self, - env: gym.Env, - mp: MPInterface, - controller: BaseController, - duration: float, - render_mode: str = None, - verbose: int = 1, - weight_scale: float = 1, - sequencing=True, - reward_aggregation=np.mean, - ): + Args: + env: The (wrapped) environment this wrapper is applied on + trajectory_generator: Generates the full or partial trajectory + tracking_controller: Translates the desired trajectory to raw action sequences + duration: Length of the trajectory of the movement primitive in seconds + verbose: level of detail for returned values in info dict. + reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory + reward, default summation over all values. + """ super().__init__() self.env = env - try: - self.dt = env.dt - except AttributeError: - raise AttributeError("step based environment needs to have a function 'dt' ") self.duration = duration self.traj_steps = int(duration / self.dt) self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps # duration = self.env.max_episode_steps * self.dt - self.mp = mp - self.env = env - self.controller = controller - self.weight_scale = weight_scale - - # rendering - self.render_mode = render_mode - self.render_kwargs = {} + # trajectory generation + self.trajectory_generator = trajectory_generator + self.tracking_controller = tracking_controller + # self.weight_scale = weight_scale self.time_steps = np.linspace(0, self.duration, self.traj_steps) - self.mp.set_mp_times(self.time_steps) - # self.mp.set_mp_duration(self.time_steps, dt) - # action_bounds = np.inf * np.ones((np.prod(self.mp.num_params))) - self.mp_action_space = self.get_mp_action_space() + self.trajectory_generator.set_mp_times(self.time_steps) + # self.trajectory_generator.set_mp_duration(self.time_steps, dt) + # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params))) + self.reward_aggregation = reward_aggregation + # spaces + self.mp_action_space = self.get_mp_action_space() self.action_space = self.get_action_space() - self.active_obs = self.set_active_obs() - self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs], - high=self.env.observation_space.high[self.active_obs], + self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask], + high=self.env.observation_space.high[self.env.context_mask], dtype=self.env.observation_space.dtype) + # rendering + self.render_mode = None + self.render_kwargs = {} + self.verbose = verbose + @property + def dt(self): + return self.env.dt + + def observation(self, observation): + return observation[self.env.context_mask] + def get_trajectory(self, action: np.ndarray) -> Tuple: # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at # the beginning of the array. - ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay) - scaled_mp_params = action.copy() - scaled_mp_params[ignore_indices:] *= self.weight_scale - self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high)) - self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel) - traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True) + # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay) + # scaled_mp_params = action.copy() + # scaled_mp_params[ignore_indices:] *= self.weight_scale + + clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high) + self.trajectory_generator.set_params(clipped_params) + self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, + bc_vel=self.current_vel) + traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True) trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] trajectory = trajectory_tensor.numpy() @@ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): # TODO: Do we need this or does mp_pytorch have this? if self.post_traj_steps > 0: trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])]) - velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))]) + velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))]) return trajectory, velocity def get_mp_action_space(self): - """This function can be used to set up an individual space for the parameters of the mp.""" - min_action_bounds, max_action_bounds = self.mp.get_param_bounds() + """This function can be used to set up an individual space for the parameters of the trajectory_generator.""" + min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds() mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), dtype=np.float32) return mp_action_space @@ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): except AttributeError: return self.get_mp_action_space() - def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: - """ - Used to extract the parameters for the motion primitive and other parameters from an action array which might - include other actions like ball releasing time for the beer pong environment. - This only needs to be overwritten if the action space is modified. - Args: - action: a vector instance of the whole action space, includes mp parameters and additional parameters if - specified, else only mp parameters - - Returns: - Tuple: mp_arguments and other arguments - """ - return action, None - - def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[ - np.ndarray]: - """ - This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the - Beerpong env. The parameters used should not be part of the motion primitive parameters. - Returns step_action by default, can be overwritten in individual mp_wrappers. - Args: - t: the current time step of the episode - env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback - (e.g. ball release time in Beer Pong) - step_action: the current step-based action - - Returns: - modified step action - """ - return step_action - - @abstractmethod - def set_active_obs(self) -> np.ndarray: - """ - This function defines the contexts. The contexts are defined as specific observations. - Returns: - boolearn array representing the indices of the observations - - """ - return np.ones(self.env.observation_space.shape[0], dtype=bool) - - @property - @abstractmethod - def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: - """ - Returns the current position of the action/control dimension. - The dimensionality has to match the action/control dimension. - This is not required when exclusively using velocity control, - it should, however, be implemented regardless. - E.g. The joint positions that are directly or indirectly controlled by the action. - """ - raise NotImplementedError() - - @property - @abstractmethod - def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: - """ - Returns the current velocity of the action/control dimension. - The dimensionality has to match the action/control dimension. - This is not required when exclusively using position control, - it should, however, be implemented regardless. - E.g. The joint velocities that are directly or indirectly controlled by the action. - """ - raise NotImplementedError() - def step(self, action: np.ndarray): """ This function generates a trajectory based on a MP and then does the usual loop over reset and step""" # TODO: Think about sequencing @@ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): # TODO # self.time_steps = np.linspace(0, learned_duration, self.traj_steps) - # self.mp.set_mp_times(self.time_steps) + # self.trajectory_generator.set_mp_times(self.time_steps) trajectory_length = len(trajectory) + rewards = np.zeros(shape=(trajectory_length,)) if self.verbose >= 2: actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape) observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape, dtype=self.env.observation_space.dtype) - rewards = np.zeros(shape=(trajectory_length,)) - trajectory_return = 0 infos = dict() + done = False for t, pos_vel in enumerate(zip(trajectory, velocity)): - step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel) + step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, + self.current_vel) step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) # print('step/clipped action ratio: ', step_action/c_action) obs, c_reward, done, info = self.env.step(c_action) + rewards[t] = c_reward + if self.verbose >= 2: actions[t, :] = c_action - rewards[t] = c_reward observations[t, :] = obs - trajectory_return += c_reward + for k, v in info.items(): elems = infos.get(k, [None] * trajectory_length) elems[t] = v infos[k] = elems - # infos['step_infos'].append(info) - if self.render_mode: + + if self.render_mode is not None: self.render(mode=self.render_mode, **self.render_kwargs) - if done or do_replanning(kwargs): + + if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t): break + infos.update({k: v[:t + 1] for k, v in infos.items()}) + if self.verbose >= 2: infos['trajectory'] = trajectory infos['step_actions'] = actions[:t + 1] infos['step_observations'] = observations[:t + 1] infos['step_rewards'] = rewards[:t + 1] + infos['trajectory_length'] = t + 1 - done = True + trajectory_return = self.reward_aggregation(rewards[:t + 1]) return self.get_observation_from_step(obs), trajectory_return, done, infos def reset(self): diff --git a/alr_envs/mp/controllers/meta_world_controller.py b/alr_envs/mp/controllers/meta_world_controller.py index 07988e5..5747f9e 100644 --- a/alr_envs/mp/controllers/meta_world_controller.py +++ b/alr_envs/mp/controllers/meta_world_controller.py @@ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController class MetaWorldController(BaseController): """ A Metaworld Controller. Using position and velocity information from a provided environment, - the controller calculates a response based on the desired position and velocity. - Unlike the other Controllers, this is a special controller for MetaWorld environments. + the tracking_controller calculates a response based on the desired position and velocity. + Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments. They use a position delta for the xyz coordinates and a raw position for the gripper opening. :param env: A position environment diff --git a/alr_envs/mp/controllers/pd_controller.py b/alr_envs/mp/controllers/pd_controller.py index d22f6b4..140aeee 100644 --- a/alr_envs/mp/controllers/pd_controller.py +++ b/alr_envs/mp/controllers/pd_controller.py @@ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController class PDController(BaseController): """ A PD-Controller. Using position and velocity information from a provided environment, - the controller calculates a response based on the desired position and velocity + the tracking_controller calculates a response based on the desired position and velocity :param env: A position environment :param p_gains: Factors for the proportional gains diff --git a/alr_envs/mp/controllers/pos_controller.py b/alr_envs/mp/controllers/pos_controller.py index bec3c68..5570307 100644 --- a/alr_envs/mp/controllers/pos_controller.py +++ b/alr_envs/mp/controllers/pos_controller.py @@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController class PosController(BaseController): """ - A Position Controller. The controller calculates a response only based on the desired position. + A Position Controller. The tracking_controller calculates a response only based on the desired position. """ def get_action(self, des_pos, des_vel, c_pos, c_vel): return des_pos diff --git a/alr_envs/mp/controllers/vel_controller.py b/alr_envs/mp/controllers/vel_controller.py index 38128be..67bab2a 100644 --- a/alr_envs/mp/controllers/vel_controller.py +++ b/alr_envs/mp/controllers/vel_controller.py @@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController class VelController(BaseController): """ - A Velocity Controller. The controller calculates a response only based on the desired velocity. + A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity. """ def get_action(self, des_pos, des_vel, c_pos, c_vel): return des_vel diff --git a/alr_envs/mp/mp_factory.py b/alr_envs/mp/mp_factory.py index 5cf7231..d2c5460 100644 --- a/alr_envs/mp/mp_factory.py +++ b/alr_envs/mp/mp_factory.py @@ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator ALL_TYPES = ["promp", "dmp", "idmp"] -def get_movement_primitive( - movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs +def get_trajectory_generator( + trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs ): - movement_primitives_type = movement_primitives_type.lower() - if movement_primitives_type == "promp": + trajectory_generator_type = trajectory_generator_type.lower() + if trajectory_generator_type == "promp": return ProMP(basis_generator, action_dim, **kwargs) - elif movement_primitives_type == "dmp": + elif trajectory_generator_type == "dmp": return DMP(basis_generator, action_dim, **kwargs) - elif movement_primitives_type == 'idmp': + elif trajectory_generator_type == 'idmp': return IDMP(basis_generator, action_dim, **kwargs) else: - raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, " + raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, " f"please choose one of {ALL_TYPES}.") \ No newline at end of file diff --git a/alr_envs/mp/raw_interface_wrapper.py b/alr_envs/mp/raw_interface_wrapper.py new file mode 100644 index 0000000..45d5daf --- /dev/null +++ b/alr_envs/mp/raw_interface_wrapper.py @@ -0,0 +1,88 @@ +from typing import Union, Tuple + +import gym +import numpy as np +from abc import abstractmethod + + +class RawInterfaceWrapper(gym.Wrapper): + + @property + @abstractmethod + def context_mask(self) -> np.ndarray: + """ + This function defines the contexts. The contexts are defined as specific observations. + Returns: + bool array representing the indices of the observations + + """ + return np.ones(self.env.observation_space.shape[0], dtype=bool) + + @property + @abstractmethod + def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: + """ + Returns the current position of the action/control dimension. + The dimensionality has to match the action/control dimension. + This is not required when exclusively using velocity control, + it should, however, be implemented regardless. + E.g. The joint positions that are directly or indirectly controlled by the action. + """ + raise NotImplementedError() + + @property + @abstractmethod + def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: + """ + Returns the current velocity of the action/control dimension. + The dimensionality has to match the action/control dimension. + This is not required when exclusively using position control, + it should, however, be implemented regardless. + E.g. The joint velocities that are directly or indirectly controlled by the action. + """ + raise NotImplementedError() + + @property + @abstractmethod + def dt(self) -> float: + """ + Control frequency of the environment + Returns: float + + """ + + def do_replanning(self, pos, vel, s, a, t): + # return t % 100 == 0 + # return bool(self.replanning_model(s)) + return False + + def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: + """ + Used to extract the parameters for the motion primitive and other parameters from an action array which might + include other actions like ball releasing time for the beer pong environment. + This only needs to be overwritten if the action space is modified. + Args: + action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if + specified, else only trajectory_generator parameters + + Returns: + Tuple: mp_arguments and other arguments + """ + return action, None + + def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[ + np.ndarray]: + """ + This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the + Beerpong env. The parameters used should not be part of the motion primitive parameters. + Returns step_action by default, can be overwritten in individual mp_wrappers. + Args: + t: the current time step of the episode + env_spec_params: the environment specific parameter, as defined in function _episode_callback + (e.g. ball release time in Beer Pong) + step_action: the current step-based action + + Returns: + modified step action + """ + return step_action diff --git a/alr_envs/open_ai/__init__.py b/alr_envs/open_ai/__init__.py index 41b770f..04610fa 100644 --- a/alr_envs/open_ai/__init__.py +++ b/alr_envs/open_ai/__init__.py @@ -21,7 +21,7 @@ register( kwargs={ "name": "alr_envs:MountainCarContinuous-v1", "wrappers": [classic_control.continuous_mountain_car.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 4, "duration": 2, @@ -43,7 +43,7 @@ register( kwargs={ "name": "gym.envs.classic_control:MountainCarContinuous-v0", "wrappers": [classic_control.continuous_mountain_car.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 1, "num_basis": 4, "duration": 19.98, @@ -65,7 +65,7 @@ register( kwargs={ "name": "gym.envs.mujoco:Reacher-v2", "wrappers": [mujoco.reacher_v2.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 2, "num_basis": 6, "duration": 1, @@ -87,7 +87,7 @@ register( kwargs={ "name": "gym.envs.robotics:FetchSlideDense-v1", "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 2, @@ -105,7 +105,7 @@ register( kwargs={ "name": "gym.envs.robotics:FetchSlide-v1", "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 2, @@ -123,7 +123,7 @@ register( kwargs={ "name": "gym.envs.robotics:FetchReachDense-v1", "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 2, @@ -141,7 +141,7 @@ register( kwargs={ "name": "gym.envs.robotics:FetchReach-v1", "wrappers": [FlattenObservation, robotics.fetch.MPWrapper], - "mp_kwargs": { + "traj_gen_kwargs": { "num_dof": 4, "num_basis": 5, "duration": 2, diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index dbefeee..9af0a2d 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping import gym import numpy as np from gym.envs.registration import EnvSpec - -from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper -from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper from mp_pytorch import MPInterface from alr_envs.mp.basis_generator_factory import get_basis_generator +from alr_envs.mp.black_box_wrapper import BlackBoxWrapper from alr_envs.mp.controllers.base_controller import BaseController from alr_envs.mp.controllers.controller_factory import get_controller -from alr_envs.mp.mp_factory import get_movement_primitive -from alr_envs.mp.episodic_wrapper import EpisodicWrapper +from alr_envs.mp.mp_factory import get_trajectory_generator from alr_envs.mp.phase_generator_factory import get_phase_generator +from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): @@ -100,9 +98,8 @@ def make(env_id: str, seed, **kwargs): def _make_wrapped_env( - env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController, - ep_wrapper_kwargs: Mapping, seed=1, **kwargs - ): + env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs +): """ Helper function for creating a wrapped gym environment using MPs. It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is @@ -118,73 +115,74 @@ def _make_wrapped_env( """ # _env = gym.make(env_id) _env = make(env_id, seed, **kwargs) - has_episodic_wrapper = False + has_black_box_wrapper = False for w in wrappers: - # only wrap the environment if not EpisodicWrapper, e.g. for vision - if not issubclass(w, EpisodicWrapper): - _env = w(_env) - else: # if EpisodicWrapper, use specific constructor - has_episodic_wrapper = True - _env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs) - if not has_episodic_wrapper: - raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.") + # only wrap the environment if not BlackBoxWrapper, e.g. for vision + if issubclass(w, RawInterfaceWrapper): + has_black_box_wrapper = True + _env = w(_env) + if not has_black_box_wrapper: + raise ValueError("An RawInterfaceWrapper is required in order to leverage movement primitive environments.") return _env -def make_mp_from_kwargs( - env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping, +def make_bb_env( + env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping, controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1, - sequenced=False, **kwargs - ): + sequenced=False, **kwargs): """ This can also be used standalone for manually building a custom DMP environment. Args: - ep_wrapper_kwargs: - basis_kwargs: - phase_kwargs: - controller_kwargs: + black_box_wrapper_kwargs: kwargs for the black-box wrapper + basis_kwargs: kwargs for the basis generator + phase_kwargs: kwargs for the phase generator + controller_kwargs: kwargs for the tracking controller env_id: base_env_name, - wrappers: list of wrappers (at least an EpisodicWrapper), + wrappers: list of wrappers (at least an BlackBoxWrapper), seed: seed of environment sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory, this behavior is much closer to step based learning. - mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP + traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP Returns: DMP wrapped gym env """ - _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) - dummy_env = make(env_id, seed) - if ep_wrapper_kwargs.get('duration', None) is None: - ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt + _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None)) + _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) + + if black_box_wrapper_kwargs.get('duration', None) is None: + black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt if phase_kwargs.get('tau', None) is None: - phase_kwargs['tau'] = ep_wrapper_kwargs['duration'] - mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item()) + phase_kwargs['tau'] = black_box_wrapper_kwargs['duration'] + traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item()) + phase_gen = get_phase_generator(**phase_kwargs) basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs) controller = get_controller(**controller_kwargs) - mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs) - _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller, - ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs) - return _env + traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs) + + bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller, + **black_box_wrapper_kwargs) + + return bb_env -def make_mp_env_helper(**kwargs): +def make_bb_env_helper(**kwargs): """ - Helper function for registering a DMP gym environments. + Helper function for registering a black box gym environment. Args: **kwargs: expects at least the following: { "name": base environment name. - "wrappers": list of wrappers (at least an EpisodicWrapper is required), - "movement_primitives_kwargs": { - "movement_primitives_type": type_of_your_movement_primitive, + "wrappers": list of wrappers (at least an BlackBoxWrapper is required), + "traj_gen_kwargs": { + "trajectory_generator_type": type_of_your_movement_primitive, non default arguments for the movement primitive instance ... } "controller_kwargs": { "controller_type": type_of_your_controller, - non default arguments for the controller instance + non default arguments for the tracking_controller instance ... }, "basis_generator_kwargs": { @@ -205,95 +203,17 @@ def make_mp_env_helper(**kwargs): seed = kwargs.pop("seed", None) wrappers = kwargs.pop("wrappers") - mp_kwargs = kwargs.pop("movement_primitives_kwargs") - ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs') - contr_kwargs = kwargs.pop("controller_kwargs") - phase_kwargs = kwargs.pop("phase_generator_kwargs") - basis_kwargs = kwargs.pop("basis_generator_kwargs") + traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {}) + black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {}) + contr_kwargs = kwargs.pop("controller_kwargs", {}) + phase_kwargs = kwargs.pop("phase_generator_kwargs", {}) + basis_kwargs = kwargs.pop("basis_generator_kwargs", {}) - return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs, - mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs, - basis_kwargs=basis_kwargs, **kwargs, seed=seed) - - -def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): - """ - This can also be used standalone for manually building a custom DMP environment. - Args: - env_id: base_env_name, - wrappers: list of wrappers (at least an MPEnvWrapper), - seed: seed of environment - mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP - - Returns: DMP wrapped gym env - - """ - _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) - - _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) - - _verify_dof(_env, mp_kwargs.get("num_dof")) - - return DmpWrapper(_env, **mp_kwargs) - - -def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs): - """ - This can also be used standalone for manually building a custom ProMP environment. - Args: - env_id: base_env_name, - wrappers: list of wrappers (at least an MPEnvWrapper), - mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int} - - Returns: ProMP wrapped gym env - - """ - _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) - - _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) - - _verify_dof(_env, mp_kwargs.get("num_dof")) - - return ProMPWrapper(_env, **mp_kwargs) - - -def make_dmp_env_helper(**kwargs): - """ - Helper function for registering a DMP gym environments. - Args: - **kwargs: expects at least the following: - { - "name": base_env_name, - "wrappers": list of wrappers (at least an MPEnvWrapper), - "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP - } - - Returns: DMP wrapped gym env - - """ - seed = kwargs.pop("seed", None) - return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed, - mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) - - -def make_promp_env_helper(**kwargs): - """ - Helper function for registering ProMP gym environments. - This can also be used standalone for manually building a custom ProMP environment. - Args: - **kwargs: expects at least the following: - { - "name": base_env_name, - "wrappers": list of wrappers (at least an MPEnvWrapper), - "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int} - } - - Returns: ProMP wrapped gym env - - """ - seed = kwargs.pop("seed", None) - return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed, - mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs) + return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers, + black_box_wrapper_kwargs=black_box_kwargs, + traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs, + phase_kwargs=phase_kwargs, + basis_kwargs=basis_kwargs, **kwargs, seed=seed) def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): @@ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[ It can be found in the BaseMP class. Args: - mp_time_limit: max trajectory length of mp in seconds + mp_time_limit: max trajectory length of trajectory_generator in seconds env_time_limit: max trajectory length of DMC environment in seconds Returns: