restructuring

parent 8fe6a83271
commit 02b8a65bab
@@ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper]
mp_kwargs = {...}
kwargs = {...}
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required):
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# OR for a deterministic ProMP (other traj_gen_kwargs are required):
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)

rewards = 0
obs = env.reset()

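# Editor's note: illustrative continuation of the README snippet above, not part of this diff.
# A typical episodic rollout with the MP-based env; `env`, `rewards`, and `obs` come from the lines above.
# for i in range(5):
#     ac = env.action_space.sample()          # one MP parameter vector per episode
#     obs, reward, done, info = env.step(ac)  # executes the full trajectory internally
#     rewards += reward
#     if done:
#         rewards = 0
#         obs = env.reset()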
@@ -346,7 +346,7 @@ for _v in _versions:
        kwargs={
            "name": f"alr_envs:{_v}",
            "wrappers": [classic_control.simple_reacher.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 2 if "long" not in _v.lower() else 5,
                "num_basis": 5,
                "duration": 2,
@@ -386,7 +386,7 @@ register(
    kwargs={
        "name": "alr_envs:ViaPointReacher-v0",
        "wrappers": [classic_control.viapoint_reacher.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 5,
            "num_basis": 5,
            "duration": 2,
@@ -424,7 +424,7 @@ for _v in _versions:
        kwargs={
            "name": f"alr_envs:HoleReacher-{_v}",
            "wrappers": [classic_control.hole_reacher.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 5,
                "num_basis": 5,
                "duration": 2,
@@ -467,7 +467,7 @@ for _v in _versions:
        kwargs={
            "name": f"alr_envs:{_v}",
            "wrappers": [mujoco.reacher.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 5 if "long" not in _v.lower() else 7,
                "num_basis": 2,
                "duration": 4,

@@ -1,12 +1,13 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple
import numpy as np

from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper


class NewMPWrapper(EpisodicWrapper):
class NewMPWrapper(RawInterfaceWrapper):

    def set_active_obs(self):
    def get_context_mask(self):
        return np.hstack([
            [False] * 111, # ant has 111 dimensional observation space !!
            [True] # goal height

@@ -1,15 +1,11 @@
from typing import Tuple, Union
from typing import Union, Tuple

import numpy as np

from alr_envs.mp.episodic_wrapper import EpisodicWrapper
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper


class NewMPWrapper(EpisodicWrapper):

    # def __init__(self, replanning_model):
    #     self.replanning_model = replanning_model

class NewMPWrapper(RawInterfaceWrapper):
    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.env.sim.data.qpos[0:7].copy()
@@ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper):
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.env.sim.data.qvel[0:7].copy()

    def set_active_obs(self):
    def get_context_mask(self):
        return np.hstack([
            [False] * 7,  # cos
            [False] * 7,  # sin
@@ -29,11 +25,6 @@ class NewMPWrapper(EpisodicWrapper):
            [False]  # env steps
        ])

    def do_replanning(self, pos, vel, s, a, t, last_replan_step):
        return False
        # const = np.arange(0, 1000, 10)
        # return bool(self.replanning_model(s))

    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        if self.mp.learn_tau:
            self.env.env.release_step = action[0] / self.env.dt  # Tau value

@@ -1,9 +1,9 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple
import numpy as np


class NewMPWrapper(EpisodicWrapper):
class NewMPWrapper(BlackBoxWrapper):
    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.env.sim.data.qpos[3:6].copy()
@@ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper):
    #     ])

    # Random x goal + random init pos
    def set_active_obs(self):
    def get_context_mask(self):
        return np.hstack([
                [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position
                [True] * 3,    # set to true if randomize initial pos
@@ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper):


class NewHighCtxtMPWrapper(NewMPWrapper):
    def set_active_obs(self):
    def get_context_mask(self):
        return np.hstack([
            [False] * (2 + int(not self.env.exclude_current_positions_from_observation)),  # position
            [True] * 3,  # set to true if randomize initial pos

@@ -1,9 +1,9 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple
import numpy as np


class MPWrapper(EpisodicWrapper):
class MPWrapper(BlackBoxWrapper):

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
@@ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper):
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        return self.env.sim.data.qvel.flat[:self.env.n_links]

    def set_active_obs(self):
    def get_context_mask(self):
        return np.concatenate([
            [False] * self.env.n_links,  # cos
            [False] * self.env.n_links,  # sin

@@ -15,7 +15,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.ball_in_cup.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -41,7 +41,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.ball_in_cup.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -65,7 +65,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.reacher.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -92,7 +92,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.reacher.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -117,7 +117,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.reacher.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -144,7 +144,7 @@ register(
        "time_limit": 20,
        "episode_length": 1000,
        "wrappers": [suite.reacher.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 5,
            "duration": 20,
@@ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks:
            "camera_id": 0,
            "episode_length": 1000,
            "wrappers": [suite.cartpole.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 1,
                "num_basis": 5,
                "duration": 10,
@@ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks:
            "camera_id": 0,
            "episode_length": 1000,
            "wrappers": [suite.cartpole.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 1,
                "num_basis": 5,
                "duration": 10,
@@ -230,7 +230,7 @@ register(
        "camera_id": 0,
        "episode_length": 1000,
        "wrappers": [suite.cartpole.TwoPolesMPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 5,
            "duration": 10,
@@ -259,7 +259,7 @@ register(
        "camera_id": 0,
        "episode_length": 1000,
        "wrappers": [suite.cartpole.TwoPolesMPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 5,
            "duration": 10,
@@ -286,7 +286,7 @@ register(
        "camera_id": 0,
        "episode_length": 1000,
        "wrappers": [suite.cartpole.ThreePolesMPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 5,
            "duration": 10,
@@ -315,7 +315,7 @@ register(
        "camera_id": 0,
        "episode_length": 1000,
        "wrappers": [suite.cartpole.ThreePolesMPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 5,
            "duration": 10,
@@ -342,7 +342,7 @@ register(
        # "time_limit": 1,
        "episode_length": 250,
        "wrappers": [manipulation.reach_site.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 9,
            "num_basis": 5,
            "duration": 10,
@@ -365,7 +365,7 @@ register(
        # "time_limit": 1,
        "episode_length": 250,
        "wrappers": [manipulation.reach_site.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 9,
            "num_basis": 5,
            "duration": 10,

@@ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
        "learn_goal": True,  # learn the goal position (recommended)
        "alpha_phase": 2,
        "bandwidth_factor": 2,
        "policy_type": "motor",  # controller type, 'velocity', 'position', and 'motor' (torque control)
        "policy_type": "motor",  # tracking_controller type, 'velocity', 'position', and 'motor' (torque control)
        "weights_scale": 1,  # scaling of MP weights
        "goal_scale": 1,  # scaling of learned goal position
        "policy_kwargs": {  # only required for torque control/PD-Controller
@@ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
        # "frame_skip": 1
    }
    env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
    # OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
    # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
    # OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples):
    # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)

    # This renders the full MP trajectory
    # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

@@ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
        "width": 0.025,  # width of the basis functions
        "zero_start": True,  # start from current environment position if True
        "weights_scale": 1,  # scaling of MP weights
        "policy_type": "metaworld",  # custom controller type for metaworld environments
        "policy_type": "metaworld",  # custom tracking_controller type for metaworld environments
    }

    env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
    # OR for a DMP (other mp_kwargs are required, see dmc_examples):
    # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
    # OR for a DMP (other traj_gen_kwargs are required, see dmc_examples):
    # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs)

    # This renders the full MP trajectory
    # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

@@ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=
    Returns:

    """
    # Changing the mp_kwargs is possible by providing them to gym.
    # Changing the traj_gen_kwargs is possible by providing them to gym.
    # E.g. here by providing way to many basis functions
    mp_kwargs = {
        "num_dof": 5,
@@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
    }
    env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
    # OR for a deterministic ProMP:
    # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
    # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs)

    if render:
        env.render(mode="human")

@@ -4,7 +4,7 @@ import alr_envs
def example_mp(env_name, seed=1):
    """
    Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
    For more information on motion primitive specific stuff, look at the mp examples.
    For more information on motion primitive specific stuff, look at the trajectory_generator examples.
    Args:
        env_name: ProMP env_id
        seed: seed

@@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env

def visualize(env):
    t = env.t
    pos_features = env.mp.basis_generator.basis(t)
    pos_features = env.trajectory_generator.basis_generator.basis(t)
    plt.plot(t, pos_features)
    plt.show()


@@ -19,7 +19,7 @@ for _task in _goal_change_envs:
        kwargs={
            "name": _task,
            "wrappers": [goal_change_mp_wrapper.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 4,
                "num_basis": 5,
                "duration": 6.25,
@@ -42,7 +42,7 @@ for _task in _object_change_envs:
        kwargs={
            "name": _task,
            "wrappers": [object_change_mp_wrapper.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 4,
                "num_basis": 5,
                "duration": 6.25,
@@ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs:
        kwargs={
            "name": _task,
            "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 4,
                "num_basis": 5,
                "duration": 6.25,
@@ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs:
        kwargs={
            "name": _task,
            "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
            "mp_kwargs": {
            "traj_gen_kwargs": {
                "num_dof": 4,
                "num_basis": 5,
                "duration": 6.25,

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Union, Tuple
from abc import ABC
from typing import Tuple

import gym
import numpy as np
@@ -7,77 +7,77 @@ from gym import spaces
from mp_pytorch.mp.mp_interfaces import MPInterface

from alr_envs.mp.controllers.base_controller import BaseController
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper


class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
class BlackBoxWrapper(gym.ObservationWrapper, ABC):

    def __init__(self,
                 env: RawInterfaceWrapper,
                 trajectory_generator: MPInterface, tracking_controller: BaseController,
                 duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum):
        """
    Base class for movement primitive based gym.Wrapper implementations.
        gym.Wrapper for leveraging a black box approach with a trajectory generator.

        Args:
            env: The (wrapped) environment this wrapper is applied on
        num_dof: Dimension of the action space of the wrapped env
        num_basis: Number of basis functions per dof
            trajectory_generator: Generates the full or partial trajectory
            tracking_controller: Translates the desired trajectory to raw action sequences
            duration: Length of the trajectory of the movement primitive in seconds
        controller: Type or object defining the policy that is used to generate action based on the trajectory
        weight_scale: Scaling parameter for the actions given to this wrapper
        render_mode: Equivalent to gym render mode
            verbose: level of detail for returned values in info dict.
            reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
                reward, default summation over all values.
        """

    def __init__(
            self,
            env: gym.Env,
            mp: MPInterface,
            controller: BaseController,
            duration: float,
            render_mode: str = None,
            verbose: int = 1,
            weight_scale: float = 1,
            sequencing=True,
            reward_aggregation=np.mean,
            ):
        super().__init__()

        self.env = env
        try:
            self.dt = env.dt
        except AttributeError:
            raise AttributeError("step based environment needs to have a function 'dt' ")
        self.duration = duration
        self.traj_steps = int(duration / self.dt)
        self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
        # duration = self.env.max_episode_steps * self.dt

        self.mp = mp
        self.env = env
        self.controller = controller
        self.weight_scale = weight_scale

        # rendering
        self.render_mode = render_mode
        self.render_kwargs = {}
        # trajectory generation
        self.trajectory_generator = trajectory_generator
        self.tracking_controller = tracking_controller
        # self.weight_scale = weight_scale
        self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        self.mp.set_mp_times(self.time_steps)
        # self.mp.set_mp_duration(self.time_steps, dt)
        # action_bounds = np.inf * np.ones((np.prod(self.mp.num_params)))
        self.mp_action_space = self.get_mp_action_space()
        self.trajectory_generator.set_mp_times(self.time_steps)
        # self.trajectory_generator.set_mp_duration(self.time_steps, dt)
        # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
        self.reward_aggregation = reward_aggregation

        # spaces
        self.mp_action_space = self.get_mp_action_space()
        self.action_space = self.get_action_space()
        self.active_obs = self.set_active_obs()
        self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs],
                                            high=self.env.observation_space.high[self.active_obs],
        self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask],
                                            high=self.env.observation_space.high[self.env.context_mask],
                                            dtype=self.env.observation_space.dtype)

        # rendering
        self.render_mode = None
        self.render_kwargs = {}

        self.verbose = verbose

    @property
    def dt(self):
        return self.env.dt

    def observation(self, observation):
        return observation[self.env.context_mask]

    def get_trajectory(self, action: np.ndarray) -> Tuple:
        # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
        #  the beginning of the array.
        ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay)
        scaled_mp_params = action.copy()
        scaled_mp_params[ignore_indices:] *= self.weight_scale
        self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high))
        self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel)
        traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True)
        # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
        # scaled_mp_params = action.copy()
        # scaled_mp_params[ignore_indices:] *= self.weight_scale

        clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
        self.trajectory_generator.set_params(clipped_params)
        self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
                                                          bc_vel=self.current_vel)
        traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
        trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']

        trajectory = trajectory_tensor.numpy()
@@ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
        # TODO: Do we need this or does mp_pytorch have this?
        if self.post_traj_steps > 0:
            trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))])
            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))])

        return trajectory, velocity

    def get_mp_action_space(self):
        """This function can be used to set up an individual space for the parameters of the mp."""
        min_action_bounds, max_action_bounds = self.mp.get_param_bounds()
        """This function can be used to set up an individual space for the parameters of the trajectory_generator."""
        min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds()
        mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
                                         dtype=np.float32)
        return mp_action_space
@@ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
        except AttributeError:
            return self.get_mp_action_space()

    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        """
        Used to extract the parameters for the motion primitive and other parameters from an action array which might
        include other actions like ball releasing time for the beer pong environment.
        This only needs to be overwritten if the action space is modified.
        Args:
            action: a vector instance of the whole action space, includes mp parameters and additional parameters if
            specified, else only mp parameters

        Returns:
            Tuple: mp_arguments and other arguments
        """
        return action, None

    def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
        np.ndarray]:
        """
        This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
        Beerpong env. The parameters used should not be part of the motion primitive parameters.
        Returns step_action by default, can be overwritten in individual mp_wrappers.
        Args:
            t: the current time step of the episode
            env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback
            (e.g. ball release time in Beer Pong)
            step_action: the current step-based action

        Returns:
            modified step action
        """
        return step_action

    @abstractmethod
    def set_active_obs(self) -> np.ndarray:
        """
        This function defines the contexts. The contexts are defined as specific observations.
        Returns:
            boolearn array representing the indices of the observations

        """
        return np.ones(self.env.observation_space.shape[0], dtype=bool)

    @property
    @abstractmethod
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """
            Returns the current position of the action/control dimension.
            The dimensionality has to match the action/control dimension.
            This is not required when exclusively using velocity control,
            it should, however, be implemented regardless.
            E.g. The joint positions that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """
            Returns the current velocity of the action/control dimension.
            The dimensionality has to match the action/control dimension.
            This is not required when exclusively using position control,
            it should, however, be implemented regardless.
            E.g. The joint velocities that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    def step(self, action: np.ndarray):
        """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
        # TODO: Think about sequencing
@@ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):

        # TODO
        # self.time_steps = np.linspace(0, learned_duration, self.traj_steps)
        # self.mp.set_mp_times(self.time_steps)
        # self.trajectory_generator.set_mp_times(self.time_steps)

        trajectory_length = len(trajectory)
        rewards = np.zeros(shape=(trajectory_length,))
        if self.verbose >= 2:
            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                    dtype=self.env.observation_space.dtype)
            rewards = np.zeros(shape=(trajectory_length,))
        trajectory_return = 0

        infos = dict()
        done = False

        for t, pos_vel in enumerate(zip(trajectory, velocity)):
            step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel)
            step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos,
                                                              self.current_vel)
            step_action = self._step_callback(t, env_spec_params, step_action)  # include possible callback info
            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
            # print('step/clipped action ratio: ', step_action/c_action)
            obs, c_reward, done, info = self.env.step(c_action)
            rewards[t] = c_reward

            if self.verbose >= 2:
                actions[t, :] = c_action
                rewards[t] = c_reward
                observations[t, :] = obs
            trajectory_return += c_reward

            for k, v in info.items():
                elems = infos.get(k, [None] * trajectory_length)
                elems[t] = v
                infos[k] = elems
            # infos['step_infos'].append(info)
            if self.render_mode:

            if self.render_mode is not None:
                self.render(mode=self.render_mode, **self.render_kwargs)
            if done or do_replanning(kwargs):

            if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
                break

        infos.update({k: v[:t + 1] for k, v in infos.items()})

        if self.verbose >= 2:
            infos['trajectory'] = trajectory
            infos['step_actions'] = actions[:t + 1]
            infos['step_observations'] = observations[:t + 1]
            infos['step_rewards'] = rewards[:t + 1]

        infos['trajectory_length'] = t + 1
        done = True
        trajectory_return = self.reward_aggregation(rewards[:t + 1])
        return self.get_observation_from_step(obs), trajectory_return, done, infos

    def reset(self):
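# Editor's note: hedged usage sketch of the renamed wrapper, pieced together from the signatures in
# this diff; the concrete raw_env/traj_gen/controller objects are hypothetical placeholders.
# bb_env = BlackBoxWrapper(env=raw_env,                      # a RawInterfaceWrapper around the base env
#                          trajectory_generator=traj_gen,    # e.g. a ProMP from mp_pytorch
#                          tracking_controller=controller,   # e.g. a PDController
#                          duration=2.0,                     # trajectory length in seconds
#                          reward_aggregation=np.sum)        # step rewards -> single trajectory return
# params = bb_env.action_space.sample()                      # one flat MP parameter vector
# context_obs, trajectory_return, done, info = bb_env.step(params)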
@@ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController
class MetaWorldController(BaseController):
    """
    A Metaworld Controller. Using position and velocity information from a provided environment,
    the controller calculates a response based on the desired position and velocity.
    Unlike the other Controllers, this is a special controller for MetaWorld environments.
    the tracking_controller calculates a response based on the desired position and velocity.
    Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
    They use a position delta for the xyz coordinates and a raw position for the gripper opening.

    :param env: A position environment

@@ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
class PDController(BaseController):
    """
    A PD-Controller. Using position and velocity information from a provided environment,
    the controller calculates a response based on the desired position and velocity
    the tracking_controller calculates a response based on the desired position and velocity

    :param env: A position environment
    :param p_gains: Factors for the proportional gains

@@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController

class PosController(BaseController):
    """
    A Position Controller. The controller calculates a response only based on the desired position.
    A Position Controller. The tracking_controller calculates a response only based on the desired position.
    """
    def get_action(self, des_pos, des_vel, c_pos, c_vel):
        return des_pos

@@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController

class VelController(BaseController):
    """
    A Velocity Controller. The controller calculates a response only based on the desired velocity.
    A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity.
    """
    def get_action(self, des_pos, des_vel, c_pos, c_vel):
        return des_vel

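# Editor's note: the PD law sketched below is the textbook form the PDController docstring describes
# (proportional on the position error, derivative on the velocity error); it is an illustration of
# what a tracking_controller computes, not a verbatim copy of this repository's implementation.
# def get_action(self, des_pos, des_vel, c_pos, c_vel):
#     return self.p_gains * (des_pos - c_pos) + self.d_gains * (des_vel - c_vel)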
@@ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator
ALL_TYPES = ["promp", "dmp", "idmp"]


def get_movement_primitive(
        movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
def get_trajectory_generator(
        trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
        ):
    movement_primitives_type = movement_primitives_type.lower()
    if movement_primitives_type == "promp":
    trajectory_generator_type = trajectory_generator_type.lower()
    if trajectory_generator_type == "promp":
        return ProMP(basis_generator, action_dim, **kwargs)
    elif movement_primitives_type == "dmp":
    elif trajectory_generator_type == "dmp":
        return DMP(basis_generator, action_dim, **kwargs)
    elif movement_primitives_type == 'idmp':
    elif trajectory_generator_type == 'idmp':
        return IDMP(basis_generator, action_dim, **kwargs)
    else:
        raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, "
        raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
                         f"please choose one of {ALL_TYPES}.")
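# Editor's note: hedged example call to the renamed factory; `basis_gen` is assumed to be whatever
# get_basis_generator(...) returned, and the action_dim value is illustrative.
# traj_gen = get_trajectory_generator(trajectory_generator_type="promp",
#                                     action_dim=7,
#                                     basis_generator=basis_gen)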
alr_envs/mp/raw_interface_wrapper.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from typing import Union, Tuple

import gym
import numpy as np
from abc import abstractmethod


class RawInterfaceWrapper(gym.Wrapper):

    @property
    @abstractmethod
    def context_mask(self) -> np.ndarray:
        """
        This function defines the contexts. The contexts are defined as specific observations.
        Returns:
            bool array representing the indices of the observations

        """
        return np.ones(self.env.observation_space.shape[0], dtype=bool)

    @property
    @abstractmethod
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """
            Returns the current position of the action/control dimension.
            The dimensionality has to match the action/control dimension.
            This is not required when exclusively using velocity control,
            it should, however, be implemented regardless.
            E.g. The joint positions that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """
            Returns the current velocity of the action/control dimension.
            The dimensionality has to match the action/control dimension.
            This is not required when exclusively using position control,
            it should, however, be implemented regardless.
            E.g. The joint velocities that are directly or indirectly controlled by the action.
        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def dt(self) -> float:
        """
        Control frequency of the environment
        Returns: float

        """

    def do_replanning(self, pos, vel, s, a, t):
        # return t % 100 == 0
        # return bool(self.replanning_model(s))
        return False

    def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
        """
        Used to extract the parameters for the motion primitive and other parameters from an action array which might
        include other actions like ball releasing time for the beer pong environment.
        This only needs to be overwritten if the action space is modified.
        Args:
            action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if
            specified, else only trajectory_generator parameters

        Returns:
            Tuple: mp_arguments and other arguments
        """
        return action, None

    def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
        np.ndarray]:
        """
        This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
        Beerpong env. The parameters used should not be part of the motion primitive parameters.
        Returns step_action by default, can be overwritten in individual mp_wrappers.
        Args:
            t: the current time step of the episode
            env_spec_params: the environment specific parameter, as defined in function _episode_callback
            (e.g. ball release time in Beer Pong)
            step_action: the current step-based action

        Returns:
            modified step action
        """
        return step_action
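# Editor's note: minimal sketch of a task-specific subclass of the new interface, modelled on the
# wrapper changes earlier in this commit; the qpos/qvel slices and the 5-element mask are
# illustrative assumptions, not code from the repository.
# class MyMPWrapper(RawInterfaceWrapper):
#     @property
#     def context_mask(self) -> np.ndarray:
#         return np.hstack([[False] * 4, [True]])   # expose only the goal dimension as context
#
#     @property
#     def current_pos(self) -> np.ndarray:
#         return self.env.sim.data.qpos[:2].copy()  # joint positions driven by the action
#
#     @property
#     def current_vel(self) -> np.ndarray:
#         return self.env.sim.data.qvel[:2].copy()  # matching joint velocities
#
#     @property
#     def dt(self) -> float:
#         return self.env.dt                        # control frequency of the base env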
@@ -21,7 +21,7 @@ register(
    kwargs={
        "name": "alr_envs:MountainCarContinuous-v1",
        "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 4,
            "duration": 2,
@@ -43,7 +43,7 @@ register(
    kwargs={
        "name": "gym.envs.classic_control:MountainCarContinuous-v0",
        "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 1,
            "num_basis": 4,
            "duration": 19.98,
@@ -65,7 +65,7 @@ register(
    kwargs={
        "name": "gym.envs.mujoco:Reacher-v2",
        "wrappers": [mujoco.reacher_v2.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 2,
            "num_basis": 6,
            "duration": 1,
@@ -87,7 +87,7 @@ register(
    kwargs={
        "name": "gym.envs.robotics:FetchSlideDense-v1",
        "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 4,
            "num_basis": 5,
            "duration": 2,
@@ -105,7 +105,7 @@ register(
    kwargs={
        "name": "gym.envs.robotics:FetchSlide-v1",
        "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 4,
            "num_basis": 5,
            "duration": 2,
@@ -123,7 +123,7 @@ register(
    kwargs={
        "name": "gym.envs.robotics:FetchReachDense-v1",
        "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 4,
            "num_basis": 5,
            "duration": 2,
@@ -141,7 +141,7 @@ register(
    kwargs={
        "name": "gym.envs.robotics:FetchReach-v1",
        "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
        "mp_kwargs": {
        "traj_gen_kwargs": {
            "num_dof": 4,
            "num_basis": 5,
            "duration": 2,

@ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping
 | 
			
		||||
import gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gym.envs.registration import EnvSpec
 | 
			
		||||
 | 
			
		||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
 | 
			
		||||
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
 | 
			
		||||
from mp_pytorch import MPInterface
 | 
			
		||||
 | 
			
		||||
from alr_envs.mp.basis_generator_factory import get_basis_generator
 | 
			
		||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
 | 
			
		||||
from alr_envs.mp.controllers.base_controller import BaseController
 | 
			
		||||
from alr_envs.mp.controllers.controller_factory import get_controller
 | 
			
		||||
from alr_envs.mp.mp_factory import get_movement_primitive
 | 
			
		||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
 | 
			
		||||
from alr_envs.mp.mp_factory import get_trajectory_generator
 | 
			
		||||
from alr_envs.mp.phase_generator_factory import get_phase_generator
 | 
			
		||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
 | 
			
		||||
@ -100,8 +98,7 @@ def make(env_id: str, seed, **kwargs):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_wrapped_env(
 | 
			
		||||
        env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController,
 | 
			
		||||
        ep_wrapper_kwargs: Mapping, seed=1, **kwargs
 | 
			
		||||
        env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Helper function for creating a wrapped gym environment using MPs.
 | 
			
		||||
@ -118,73 +115,74 @@ def _make_wrapped_env(
 | 
			
		||||
    """
 | 
			
		||||
    # _env = gym.make(env_id)
 | 
			
		||||
    _env = make(env_id, seed, **kwargs)
 | 
			
		||||
    has_episodic_wrapper = False
 | 
			
		||||
    has_black_box_wrapper = False
 | 
			
		||||
    for w in wrappers:
 | 
			
		||||
        # only wrap the environment if not EpisodicWrapper, e.g. for vision
 | 
			
		||||
        if not issubclass(w, EpisodicWrapper):
 | 
			
		||||
        # only wrap the environment if not BlackBoxWrapper, e.g. for vision
 | 
			
		||||
        if issubclass(w, RawInterfaceWrapper):
 | 
			
		||||
            has_black_box_wrapper = True
 | 
			
		||||
        _env = w(_env)
 | 
			
		||||
        else:  # if EpisodicWrapper, use specific constructor
 | 
			
		||||
            has_episodic_wrapper = True
 | 
			
		||||
            _env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs)
 | 
			
		||||
    if not has_episodic_wrapper:
 | 
			
		||||
        raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.")
 | 
			
		||||
    if not has_black_box_wrapper:
 | 
			
		||||
        raise ValueError("An RawInterfaceWrapper is required in order to leverage movement primitive environments.")
 | 
			
		||||
    return _env
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_mp_from_kwargs(
 | 
			
		||||
        env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping,
 | 
			
		||||
def make_bb_env(
 | 
			
		||||
        env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
 | 
			
		||||
        controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1,
 | 
			
		||||
        sequenced=False, **kwargs
 | 
			
		||||
        ):
 | 
			
		||||
        sequenced=False, **kwargs):
 | 
			
		||||
    """
 | 
			
		||||
    This can also be used standalone for manually building a custom DMP environment.
 | 
			
		||||
    Args:
 | 
			
		||||
        ep_wrapper_kwargs:
 | 
			
		||||
        basis_kwargs:
 | 
			
		||||
        phase_kwargs:
 | 
			
		||||
        controller_kwargs:
 | 
			
		||||
        black_box_wrapper_kwargs: kwargs for the black-box wrapper
 | 
			
		||||
        basis_kwargs: kwargs for the basis generator
 | 
			
		||||
        phase_kwargs: kwargs for the phase generator
 | 
			
		||||
        controller_kwargs: kwargs for the tracking controller
 | 
			
		||||
        env_id: base_env_name,
 | 
			
		||||
        wrappers: list of wrappers (at least an EpisodicWrapper),
 | 
			
		||||
        wrappers: list of wrappers (at least an BlackBoxWrapper),
 | 
			
		||||
        seed: seed of environment
 | 
			
		||||
        sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory,
 | 
			
		||||
                this behavior is much closer to step based learning.
 | 
			
		||||
        mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
 | 
			
		||||
        traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
 | 
			
		||||
 | 
			
		||||
    Returns: DMP wrapped gym env
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
    dummy_env = make(env_id, seed)
    if ep_wrapper_kwargs.get('duration', None) is None:
        ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt
    _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)

    if black_box_wrapper_kwargs.get('duration', None) is None:
        black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt
    if phase_kwargs.get('tau', None) is None:
        phase_kwargs['tau'] = ep_wrapper_kwargs['duration']
    mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item())
        phase_kwargs['tau'] = black_box_wrapper_kwargs['duration']
    traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item())

    phase_gen = get_phase_generator(**phase_kwargs)
    basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
    controller = get_controller(**controller_kwargs)
    mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs)
    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller,
                             ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs)
    return _env
    traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs)

    bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller,
                             **black_box_wrapper_kwargs)

    return bb_env
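
# Editor's note: a minimal usage sketch (not part of this commit), assuming a hypothetical base
# environment and wrapper; the keyword names follow the make_bb_env signature above, while the
# values inside the {...} dicts are placeholders that depend on the chosen phase, basis,
# controller and trajectory generator types.
#
#     env = make_bb_env(env_id="SomeBaseEnv-v0",                 # hypothetical id
#                       wrappers=[MyTaskWrapper],                # must subclass RawInterfaceWrapper
#                       black_box_wrapper_kwargs={},             # duration falls back to max_episode_steps * dt
#                       traj_gen_kwargs={...},
#                       controller_kwargs={...},
#                       phase_kwargs={...},                      # tau falls back to the duration above
#                       basis_kwargs={...},
#                       seed=1)
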
def make_mp_env_helper(**kwargs):
def make_bb_env_helper(**kwargs):
    """
    Helper function for registering DMP gym environments.
    Helper function for registering a black-box gym environment.
    Args:
        **kwargs: expects at least the following:
        {
        "name": base environment name.
        "wrappers": list of wrappers (at least an EpisodicWrapper is required),
        "movement_primitives_kwargs": {
            "movement_primitives_type": type_of_your_movement_primitive,
        "wrappers": list of wrappers (at least a RawInterfaceWrapper is required),
        "traj_gen_kwargs": {
            "trajectory_generator_type": type_of_your_trajectory_generator,
            non-default arguments for the trajectory generator instance
            ...
            }
        "controller_kwargs": {
            "controller_type": type_of_your_controller,
            non-default arguments for the controller instance
            non-default arguments for the tracking_controller instance
            ...
            },
        "basis_generator_kwargs": {
@ -205,97 +203,19 @@ def make_mp_env_helper(**kwargs):
    seed = kwargs.pop("seed", None)
    wrappers = kwargs.pop("wrappers")

    mp_kwargs = kwargs.pop("movement_primitives_kwargs")
    ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs')
    contr_kwargs = kwargs.pop("controller_kwargs")
    phase_kwargs = kwargs.pop("phase_generator_kwargs")
    basis_kwargs = kwargs.pop("basis_generator_kwargs")
    traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
    black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {})
    contr_kwargs = kwargs.pop("controller_kwargs", {})
    phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
    basis_kwargs = kwargs.pop("basis_generator_kwargs", {})

    return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs,
                               mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs,
    return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers,
                       black_box_wrapper_kwargs=black_box_kwargs,
                       traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs,
                       phase_kwargs=phase_kwargs,
                       basis_kwargs=basis_kwargs, **kwargs, seed=seed)
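
# Editor's note: a hypothetical call sketch (not part of this commit) mirroring the kwargs
# structure documented above; the environment id, the wrapper and all dict values are placeholders.
#
#     env = make_bb_env_helper(
#         name="SomeBaseEnv-v0",                      # hypothetical base env id
#         wrappers=[MyTaskWrapper],                   # subclass of RawInterfaceWrapper
#         traj_gen_kwargs={...},
#         controller_kwargs={...},
#         phase_generator_kwargs={...},
#         basis_generator_kwargs={...},
#         seed=1,
#     )
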
def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
    """
    This can also be used standalone for manually building a custom DMP environment.
    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        seed: seed of environment
        mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP

    Returns: DMP wrapped gym env

    """
    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))

    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)

    _verify_dof(_env, mp_kwargs.get("num_dof"))

    return DmpWrapper(_env, **mp_kwargs)


def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
    """
    This can also be used standalone for manually building a custom ProMP environment.
    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}

    Returns: ProMP wrapped gym env

    """
    _verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))

    _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)

    _verify_dof(_env, mp_kwargs.get("num_dof"))

    return ProMPWrapper(_env, **mp_kwargs)

def make_dmp_env_helper(**kwargs):
    """
    Helper function for registering DMP gym environments.
    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
        }

    Returns: DMP wrapped gym env

    """
    seed = kwargs.pop("seed", None)
    return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
                        mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)


def make_promp_env_helper(**kwargs):
    """
    Helper function for registering ProMP gym environments.
    This can also be used standalone for manually building a custom ProMP environment.
    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
        }

    Returns: ProMP wrapped gym env

    """
    seed = kwargs.pop("seed", None)
    return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
                          mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)

def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
    """
    When using DMC, check whether a manually specified time limit matches the trajectory duration the MP receives.
@ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[
    It can be found in the BaseMP class.

    Args:
        mp_time_limit: max trajectory length of mp in seconds
        mp_time_limit: max trajectory length of the trajectory_generator in seconds
        env_time_limit: max trajectory length of DMC environment in seconds

    Returns:

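
# Editor's note: a hypothetical illustration (not part of this commit) of the check described in the
# docstring above; the concrete values are placeholders, and the exact failure behavior on a mismatch
# is assumed rather than confirmed by this diff.
#
#     _verify_time_limit(mp_time_limit=2.0, env_time_limit=2.0)   # durations agree -> passes
#     _verify_time_limit(mp_time_limit=4.0, env_time_limit=2.0)   # mismatch -> expected to fail the check
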