restructuring

This commit is contained in:
Fabian 2022-06-29 09:37:18 +02:00
parent 8fe6a83271
commit 02b8a65bab
23 changed files with 280 additions and 339 deletions

View File

@ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper]
mp_kwargs = {...} mp_kwargs = {...}
kwargs = {...} kwargs = {...}
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs) env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required): # OR for a deterministic ProMP (other traj_gen_kwargs are required):
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
rewards = 0 rewards = 0
obs = env.reset() obs = env.reset()

View File

@ -346,7 +346,7 @@ for _v in _versions:
kwargs={ kwargs={
"name": f"alr_envs:{_v}", "name": f"alr_envs:{_v}",
"wrappers": [classic_control.simple_reacher.MPWrapper], "wrappers": [classic_control.simple_reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2 if "long" not in _v.lower() else 5, "num_dof": 2 if "long" not in _v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -386,7 +386,7 @@ register(
kwargs={ kwargs={
"name": "alr_envs:ViaPointReacher-v0", "name": "alr_envs:ViaPointReacher-v0",
"wrappers": [classic_control.viapoint_reacher.MPWrapper], "wrappers": [classic_control.viapoint_reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -424,7 +424,7 @@ for _v in _versions:
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{_v}", "name": f"alr_envs:HoleReacher-{_v}",
"wrappers": [classic_control.hole_reacher.MPWrapper], "wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -467,7 +467,7 @@ for _v in _versions:
kwargs={ kwargs={
"name": f"alr_envs:{_v}", "name": f"alr_envs:{_v}",
"wrappers": [mujoco.reacher.MPWrapper], "wrappers": [mujoco.reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 5 if "long" not in _v.lower() else 7, "num_dof": 5 if "long" not in _v.lower() else 7,
"num_basis": 2, "num_basis": 2,
"duration": 4, "duration": 4,

View File

@ -1,12 +1,13 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple from typing import Union, Tuple
import numpy as np import numpy as np
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(EpisodicWrapper): class NewMPWrapper(RawInterfaceWrapper):
def set_active_obs(self): def get_context_mask(self):
return np.hstack([ return np.hstack([
[False] * 111, # ant has 111 dimensional observation space !! [False] * 111, # ant has 111 dimensional observation space !!
[True] # goal height [True] # goal height

View File

@ -1,15 +1,11 @@
from typing import Tuple, Union from typing import Union, Tuple
import numpy as np import numpy as np
from alr_envs.mp.episodic_wrapper import EpisodicWrapper from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class NewMPWrapper(EpisodicWrapper): class NewMPWrapper(RawInterfaceWrapper):
# def __init__(self, replanning_model):
# self.replanning_model = replanning_model
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[0:7].copy() return self.env.sim.data.qpos[0:7].copy()
@ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel[0:7].copy() return self.env.sim.data.qvel[0:7].copy()
def set_active_obs(self): def get_context_mask(self):
return np.hstack([ return np.hstack([
[False] * 7, # cos [False] * 7, # cos
[False] * 7, # sin [False] * 7, # sin
@ -27,12 +23,7 @@ class NewMPWrapper(EpisodicWrapper):
[False] * 3, # cup_goal_diff_top [False] * 3, # cup_goal_diff_top
[True] * 2, # xy position of cup [True] * 2, # xy position of cup
[False] # env steps [False] # env steps
]) ])
def do_replanning(self, pos, vel, s, a, t, last_replan_step):
return False
# const = np.arange(0, 1000, 10)
# return bool(self.replanning_model(s))
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]: def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
if self.mp.learn_tau: if self.mp.learn_tau:

View File

@ -1,9 +1,9 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple from typing import Union, Tuple
import numpy as np import numpy as np
class NewMPWrapper(EpisodicWrapper): class NewMPWrapper(BlackBoxWrapper):
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qpos[3:6].copy() return self.env.sim.data.qpos[3:6].copy()
@ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper):
# ]) # ])
# Random x goal + random init pos # Random x goal + random init pos
def set_active_obs(self): def get_context_mask(self):
return np.hstack([ return np.hstack([
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
[True] * 3, # set to true if randomize initial pos [True] * 3, # set to true if randomize initial pos
@ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper):
class NewHighCtxtMPWrapper(NewMPWrapper): class NewHighCtxtMPWrapper(NewMPWrapper):
def set_active_obs(self): def get_context_mask(self):
return np.hstack([ return np.hstack([
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position [False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
[True] * 3, # set to true if randomize initial pos [True] * 3, # set to true if randomize initial pos

View File

@ -149,4 +149,4 @@ if __name__ == '__main__':
if d: if d:
env.reset() env.reset()
env.close() env.close()

View File

@ -1,9 +1,9 @@
from alr_envs.mp.episodic_wrapper import EpisodicWrapper from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from typing import Union, Tuple from typing import Union, Tuple
import numpy as np import numpy as np
class MPWrapper(EpisodicWrapper): class MPWrapper(BlackBoxWrapper):
@property @property
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]: def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
@ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]: def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.env.sim.data.qvel.flat[:self.env.n_links] return self.env.sim.data.qvel.flat[:self.env.n_links]
def set_active_obs(self): def get_context_mask(self):
return np.concatenate([ return np.concatenate([
[False] * self.env.n_links, # cos [False] * self.env.n_links, # cos
[False] * self.env.n_links, # sin [False] * self.env.n_links, # sin

View File

@ -15,7 +15,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.ball_in_cup.MPWrapper], "wrappers": [suite.ball_in_cup.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -41,7 +41,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.ball_in_cup.MPWrapper], "wrappers": [suite.ball_in_cup.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -65,7 +65,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.reacher.MPWrapper], "wrappers": [suite.reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -92,7 +92,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.reacher.MPWrapper], "wrappers": [suite.reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -117,7 +117,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.reacher.MPWrapper], "wrappers": [suite.reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -144,7 +144,7 @@ register(
"time_limit": 20, "time_limit": 20,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.reacher.MPWrapper], "wrappers": [suite.reacher.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 20,
@ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks:
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.MPWrapper], "wrappers": [suite.cartpole.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks:
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.MPWrapper], "wrappers": [suite.cartpole.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -230,7 +230,7 @@ register(
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.TwoPolesMPWrapper], "wrappers": [suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -259,7 +259,7 @@ register(
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.TwoPolesMPWrapper], "wrappers": [suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -286,7 +286,7 @@ register(
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.ThreePolesMPWrapper], "wrappers": [suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -315,7 +315,7 @@ register(
"camera_id": 0, "camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [suite.cartpole.ThreePolesMPWrapper], "wrappers": [suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -342,7 +342,7 @@ register(
# "time_limit": 1, # "time_limit": 1,
"episode_length": 250, "episode_length": 250,
"wrappers": [manipulation.reach_site.MPWrapper], "wrappers": [manipulation.reach_site.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 9, "num_dof": 9,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,
@ -365,7 +365,7 @@ register(
# "time_limit": 1, # "time_limit": 1,
"episode_length": 250, "episode_length": 250,
"wrappers": [manipulation.reach_site.MPWrapper], "wrappers": [manipulation.reach_site.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 9, "num_dof": 9,
"num_basis": 5, "num_basis": 5,
"duration": 10, "duration": 10,

View File

@ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"learn_goal": True, # learn the goal position (recommended) "learn_goal": True, # learn the goal position (recommended)
"alpha_phase": 2, "alpha_phase": 2,
"bandwidth_factor": 2, "bandwidth_factor": 2,
"policy_type": "motor", # controller type, 'velocity', 'position', and 'motor' (torque control) "policy_type": "motor", # tracking_controller type, 'velocity', 'position', and 'motor' (torque control)
"weights_scale": 1, # scaling of MP weights "weights_scale": 1, # scaling of MP weights
"goal_scale": 1, # scaling of learned goal position "goal_scale": 1, # scaling of learned goal position
"policy_kwargs": { # only required for torque control/PD-Controller "policy_kwargs": { # only required for torque control/PD-Controller
@ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
# "frame_skip": 1 # "frame_skip": 1
} }
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples): # OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples):
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
# This renders the full MP trajectory # This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory. # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

View File

@ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"width": 0.025, # width of the basis functions "width": 0.025, # width of the basis functions
"zero_start": True, # start from current environment position if True "zero_start": True, # start from current environment position if True
"weights_scale": 1, # scaling of MP weights "weights_scale": 1, # scaling of MP weights
"policy_type": "metaworld", # custom controller type for metaworld environments "policy_type": "metaworld", # custom tracking_controller type for metaworld environments
} }
env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a DMP (other mp_kwargs are required, see dmc_examples): # OR for a DMP (other traj_gen_kwargs are required, see dmc_examples):
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) # env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs)
# This renders the full MP trajectory # This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory. # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

View File

@ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=
Returns: Returns:
""" """
# Changing the mp_kwargs is possible by providing them to gym. # Changing the traj_gen_kwargs is possible by providing them to gym.
# E.g. here by providing way to many basis functions # E.g. here by providing way to many basis functions
mp_kwargs = { mp_kwargs = {
"num_dof": 5, "num_dof": 5,
@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
} }
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
# OR for a deterministic ProMP: # OR for a deterministic ProMP:
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) # env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs)
if render: if render:
env.render(mode="human") env.render(mode="human")

View File

@ -4,7 +4,7 @@ import alr_envs
def example_mp(env_name, seed=1): def example_mp(env_name, seed=1):
""" """
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered. Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
For more information on motion primitive specific stuff, look at the mp examples. For more information on motion primitive specific stuff, look at the trajectory_generator examples.
Args: Args:
env_name: ProMP env_id env_name: ProMP env_id
seed: seed seed: seed

View File

@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env
def visualize(env): def visualize(env):
t = env.t t = env.t
pos_features = env.mp.basis_generator.basis(t) pos_features = env.trajectory_generator.basis_generator.basis(t)
plt.plot(t, pos_features) plt.plot(t, pos_features)
plt.show() plt.show()

View File

@ -19,7 +19,7 @@ for _task in _goal_change_envs:
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_change_mp_wrapper.MPWrapper], "wrappers": [goal_change_mp_wrapper.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
@ -42,7 +42,7 @@ for _task in _object_change_envs:
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [object_change_mp_wrapper.MPWrapper], "wrappers": [object_change_mp_wrapper.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
@ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs:
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_object_change_mp_wrapper.MPWrapper], "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,
@ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs:
kwargs={ kwargs={
"name": _task, "name": _task,
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper], "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 6.25, "duration": 6.25,

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC
from typing import Union, Tuple from typing import Tuple
import gym import gym
import numpy as np import numpy as np
@ -7,77 +7,77 @@ from gym import spaces
from mp_pytorch.mp.mp_interfaces import MPInterface from mp_pytorch.mp.mp_interfaces import MPInterface
from alr_envs.mp.controllers.base_controller import BaseController from alr_envs.mp.controllers.base_controller import BaseController
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC): class BlackBoxWrapper(gym.ObservationWrapper, ABC):
"""
Base class for movement primitive based gym.Wrapper implementations.
Args: def __init__(self,
env: The (wrapped) environment this wrapper is applied on env: RawInterfaceWrapper,
num_dof: Dimension of the action space of the wrapped env trajectory_generator: MPInterface, tracking_controller: BaseController,
num_basis: Number of basis functions per dof duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum):
duration: Length of the trajectory of the movement primitive in seconds """
controller: Type or object defining the policy that is used to generate action based on the trajectory gym.Wrapper for leveraging a black box approach with a trajectory generator.
weight_scale: Scaling parameter for the actions given to this wrapper
render_mode: Equivalent to gym render mode
"""
def __init__( Args:
self, env: The (wrapped) environment this wrapper is applied on
env: gym.Env, trajectory_generator: Generates the full or partial trajectory
mp: MPInterface, tracking_controller: Translates the desired trajectory to raw action sequences
controller: BaseController, duration: Length of the trajectory of the movement primitive in seconds
duration: float, verbose: level of detail for returned values in info dict.
render_mode: str = None, reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
verbose: int = 1, reward, default summation over all values.
weight_scale: float = 1, """
sequencing=True,
reward_aggregation=np.mean,
):
super().__init__() super().__init__()
self.env = env self.env = env
try:
self.dt = env.dt
except AttributeError:
raise AttributeError("step based environment needs to have a function 'dt' ")
self.duration = duration self.duration = duration
self.traj_steps = int(duration / self.dt) self.traj_steps = int(duration / self.dt)
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
# duration = self.env.max_episode_steps * self.dt # duration = self.env.max_episode_steps * self.dt
self.mp = mp # trajectory generation
self.env = env self.trajectory_generator = trajectory_generator
self.controller = controller self.tracking_controller = tracking_controller
self.weight_scale = weight_scale # self.weight_scale = weight_scale
# rendering
self.render_mode = render_mode
self.render_kwargs = {}
self.time_steps = np.linspace(0, self.duration, self.traj_steps) self.time_steps = np.linspace(0, self.duration, self.traj_steps)
self.mp.set_mp_times(self.time_steps) self.trajectory_generator.set_mp_times(self.time_steps)
# self.mp.set_mp_duration(self.time_steps, dt) # self.trajectory_generator.set_mp_duration(self.time_steps, dt)
# action_bounds = np.inf * np.ones((np.prod(self.mp.num_params))) # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
self.mp_action_space = self.get_mp_action_space() self.reward_aggregation = reward_aggregation
# spaces
self.mp_action_space = self.get_mp_action_space()
self.action_space = self.get_action_space() self.action_space = self.get_action_space()
self.active_obs = self.set_active_obs() self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask],
self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs], high=self.env.observation_space.high[self.env.context_mask],
high=self.env.observation_space.high[self.active_obs],
dtype=self.env.observation_space.dtype) dtype=self.env.observation_space.dtype)
# rendering
self.render_mode = None
self.render_kwargs = {}
self.verbose = verbose self.verbose = verbose
@property
def dt(self):
return self.env.dt
def observation(self, observation):
return observation[self.env.context_mask]
def get_trajectory(self, action: np.ndarray) -> Tuple: def get_trajectory(self, action: np.ndarray) -> Tuple:
# TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
# the beginning of the array. # the beginning of the array.
ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay) # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
scaled_mp_params = action.copy() # scaled_mp_params = action.copy()
scaled_mp_params[ignore_indices:] *= self.weight_scale # scaled_mp_params[ignore_indices:] *= self.weight_scale
self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high))
self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel) clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True) self.trajectory_generator.set_params(clipped_params)
self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
bc_vel=self.current_vel)
traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel'] trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
trajectory = trajectory_tensor.numpy() trajectory = trajectory_tensor.numpy()
@ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
# TODO: Do we need this or does mp_pytorch have this? # TODO: Do we need this or does mp_pytorch have this?
if self.post_traj_steps > 0: if self.post_traj_steps > 0:
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])]) trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))]) velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))])
return trajectory, velocity return trajectory, velocity
def get_mp_action_space(self): def get_mp_action_space(self):
"""This function can be used to set up an individual space for the parameters of the mp.""" """This function can be used to set up an individual space for the parameters of the trajectory_generator."""
min_action_bounds, max_action_bounds = self.mp.get_param_bounds() min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds()
mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(), mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
dtype=np.float32) dtype=np.float32)
return mp_action_space return mp_action_space
@ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
except AttributeError: except AttributeError:
return self.get_mp_action_space() return self.get_mp_action_space()
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
"""
Used to extract the parameters for the motion primitive and other parameters from an action array which might
include other actions like ball releasing time for the beer pong environment.
This only needs to be overwritten if the action space is modified.
Args:
action: a vector instance of the whole action space, includes mp parameters and additional parameters if
specified, else only mp parameters
Returns:
Tuple: mp_arguments and other arguments
"""
return action, None
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
np.ndarray]:
"""
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
Beerpong env. The parameters used should not be part of the motion primitive parameters.
Returns step_action by default, can be overwritten in individual mp_wrappers.
Args:
t: the current time step of the episode
env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback
(e.g. ball release time in Beer Pong)
step_action: the current step-based action
Returns:
modified step action
"""
return step_action
@abstractmethod
def set_active_obs(self) -> np.ndarray:
"""
This function defines the contexts. The contexts are defined as specific observations.
Returns:
boolearn array representing the indices of the observations
"""
return np.ones(self.env.observation_space.shape[0], dtype=bool)
@property
@abstractmethod
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
"""
Returns the current position of the action/control dimension.
The dimensionality has to match the action/control dimension.
This is not required when exclusively using velocity control,
it should, however, be implemented regardless.
E.g. The joint positions that are directly or indirectly controlled by the action.
"""
raise NotImplementedError()
@property
@abstractmethod
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
"""
Returns the current velocity of the action/control dimension.
The dimensionality has to match the action/control dimension.
This is not required when exclusively using position control,
it should, however, be implemented regardless.
E.g. The joint velocities that are directly or indirectly controlled by the action.
"""
raise NotImplementedError()
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step""" """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
# TODO: Think about sequencing # TODO: Think about sequencing
@ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
# TODO # TODO
# self.time_steps = np.linspace(0, learned_duration, self.traj_steps) # self.time_steps = np.linspace(0, learned_duration, self.traj_steps)
# self.mp.set_mp_times(self.time_steps) # self.trajectory_generator.set_mp_times(self.time_steps)
trajectory_length = len(trajectory) trajectory_length = len(trajectory)
rewards = np.zeros(shape=(trajectory_length,))
if self.verbose >= 2: if self.verbose >= 2:
actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape) actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape, observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
dtype=self.env.observation_space.dtype) dtype=self.env.observation_space.dtype)
rewards = np.zeros(shape=(trajectory_length,))
trajectory_return = 0
infos = dict() infos = dict()
done = False
for t, pos_vel in enumerate(zip(trajectory, velocity)): for t, pos_vel in enumerate(zip(trajectory, velocity)):
step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel) step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos,
self.current_vel)
step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high) c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
# print('step/clipped action ratio: ', step_action/c_action) # print('step/clipped action ratio: ', step_action/c_action)
obs, c_reward, done, info = self.env.step(c_action) obs, c_reward, done, info = self.env.step(c_action)
rewards[t] = c_reward
if self.verbose >= 2: if self.verbose >= 2:
actions[t, :] = c_action actions[t, :] = c_action
rewards[t] = c_reward
observations[t, :] = obs observations[t, :] = obs
trajectory_return += c_reward
for k, v in info.items(): for k, v in info.items():
elems = infos.get(k, [None] * trajectory_length) elems = infos.get(k, [None] * trajectory_length)
elems[t] = v elems[t] = v
infos[k] = elems infos[k] = elems
# infos['step_infos'].append(info)
if self.render_mode: if self.render_mode is not None:
self.render(mode=self.render_mode, **self.render_kwargs) self.render(mode=self.render_mode, **self.render_kwargs)
if done or do_replanning(kwargs):
if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
break break
infos.update({k: v[:t + 1] for k, v in infos.items()}) infos.update({k: v[:t + 1] for k, v in infos.items()})
if self.verbose >= 2: if self.verbose >= 2:
infos['trajectory'] = trajectory infos['trajectory'] = trajectory
infos['step_actions'] = actions[:t + 1] infos['step_actions'] = actions[:t + 1]
infos['step_observations'] = observations[:t + 1] infos['step_observations'] = observations[:t + 1]
infos['step_rewards'] = rewards[:t + 1] infos['step_rewards'] = rewards[:t + 1]
infos['trajectory_length'] = t + 1 infos['trajectory_length'] = t + 1
done = True trajectory_return = self.reward_aggregation(rewards[:t + 1])
return self.get_observation_from_step(obs), trajectory_return, done, infos return self.get_observation_from_step(obs), trajectory_return, done, infos
def reset(self): def reset(self):

View File

@ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController
class MetaWorldController(BaseController): class MetaWorldController(BaseController):
""" """
A Metaworld Controller. Using position and velocity information from a provided environment, A Metaworld Controller. Using position and velocity information from a provided environment,
the controller calculates a response based on the desired position and velocity. the tracking_controller calculates a response based on the desired position and velocity.
Unlike the other Controllers, this is a special controller for MetaWorld environments. Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
They use a position delta for the xyz coordinates and a raw position for the gripper opening. They use a position delta for the xyz coordinates and a raw position for the gripper opening.
:param env: A position environment :param env: A position environment

View File

@ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
class PDController(BaseController): class PDController(BaseController):
""" """
A PD-Controller. Using position and velocity information from a provided environment, A PD-Controller. Using position and velocity information from a provided environment,
the controller calculates a response based on the desired position and velocity the tracking_controller calculates a response based on the desired position and velocity
:param env: A position environment :param env: A position environment
:param p_gains: Factors for the proportional gains :param p_gains: Factors for the proportional gains

View File

@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
class PosController(BaseController): class PosController(BaseController):
""" """
A Position Controller. The controller calculates a response only based on the desired position. A Position Controller. The tracking_controller calculates a response only based on the desired position.
""" """
def get_action(self, des_pos, des_vel, c_pos, c_vel): def get_action(self, des_pos, des_vel, c_pos, c_vel):
return des_pos return des_pos

View File

@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
class VelController(BaseController): class VelController(BaseController):
""" """
A Velocity Controller. The controller calculates a response only based on the desired velocity. A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity.
""" """
def get_action(self, des_pos, des_vel, c_pos, c_vel): def get_action(self, des_pos, des_vel, c_pos, c_vel):
return des_vel return des_vel

View File

@ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator
ALL_TYPES = ["promp", "dmp", "idmp"] ALL_TYPES = ["promp", "dmp", "idmp"]
def get_movement_primitive( def get_trajectory_generator(
movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
): ):
movement_primitives_type = movement_primitives_type.lower() trajectory_generator_type = trajectory_generator_type.lower()
if movement_primitives_type == "promp": if trajectory_generator_type == "promp":
return ProMP(basis_generator, action_dim, **kwargs) return ProMP(basis_generator, action_dim, **kwargs)
elif movement_primitives_type == "dmp": elif trajectory_generator_type == "dmp":
return DMP(basis_generator, action_dim, **kwargs) return DMP(basis_generator, action_dim, **kwargs)
elif movement_primitives_type == 'idmp': elif trajectory_generator_type == 'idmp':
return IDMP(basis_generator, action_dim, **kwargs) return IDMP(basis_generator, action_dim, **kwargs)
else: else:
raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, " raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
f"please choose one of {ALL_TYPES}.") f"please choose one of {ALL_TYPES}.")

View File

@ -0,0 +1,88 @@
from typing import Union, Tuple
import gym
import numpy as np
from abc import abstractmethod
class RawInterfaceWrapper(gym.Wrapper):
@property
@abstractmethod
def context_mask(self) -> np.ndarray:
"""
This function defines the contexts. The contexts are defined as specific observations.
Returns:
bool array representing the indices of the observations
"""
return np.ones(self.env.observation_space.shape[0], dtype=bool)
@property
@abstractmethod
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
"""
Returns the current position of the action/control dimension.
The dimensionality has to match the action/control dimension.
This is not required when exclusively using velocity control,
it should, however, be implemented regardless.
E.g. The joint positions that are directly or indirectly controlled by the action.
"""
raise NotImplementedError()
@property
@abstractmethod
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
"""
Returns the current velocity of the action/control dimension.
The dimensionality has to match the action/control dimension.
This is not required when exclusively using position control,
it should, however, be implemented regardless.
E.g. The joint velocities that are directly or indirectly controlled by the action.
"""
raise NotImplementedError()
@property
@abstractmethod
def dt(self) -> float:
"""
Control frequency of the environment
Returns: float
"""
def do_replanning(self, pos, vel, s, a, t):
# return t % 100 == 0
# return bool(self.replanning_model(s))
return False
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
"""
Used to extract the parameters for the motion primitive and other parameters from an action array which might
include other actions like ball releasing time for the beer pong environment.
This only needs to be overwritten if the action space is modified.
Args:
action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if
specified, else only trajectory_generator parameters
Returns:
Tuple: mp_arguments and other arguments
"""
return action, None
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
np.ndarray]:
"""
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
Beerpong env. The parameters used should not be part of the motion primitive parameters.
Returns step_action by default, can be overwritten in individual mp_wrappers.
Args:
t: the current time step of the episode
env_spec_params: the environment specific parameter, as defined in function _episode_callback
(e.g. ball release time in Beer Pong)
step_action: the current step-based action
Returns:
modified step action
"""
return step_action

View File

@ -21,7 +21,7 @@ register(
kwargs={ kwargs={
"name": "alr_envs:MountainCarContinuous-v1", "name": "alr_envs:MountainCarContinuous-v1",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper], "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 4, "num_basis": 4,
"duration": 2, "duration": 2,
@ -43,7 +43,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.classic_control:MountainCarContinuous-v0", "name": "gym.envs.classic_control:MountainCarContinuous-v0",
"wrappers": [classic_control.continuous_mountain_car.MPWrapper], "wrappers": [classic_control.continuous_mountain_car.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 4, "num_basis": 4,
"duration": 19.98, "duration": 19.98,
@ -65,7 +65,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.mujoco:Reacher-v2", "name": "gym.envs.mujoco:Reacher-v2",
"wrappers": [mujoco.reacher_v2.MPWrapper], "wrappers": [mujoco.reacher_v2.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 6, "num_basis": 6,
"duration": 1, "duration": 1,
@ -87,7 +87,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchSlideDense-v1", "name": "gym.envs.robotics:FetchSlideDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -105,7 +105,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchSlide-v1", "name": "gym.envs.robotics:FetchSlide-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -123,7 +123,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchReachDense-v1", "name": "gym.envs.robotics:FetchReachDense-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -141,7 +141,7 @@ register(
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchReach-v1", "name": "gym.envs.robotics:FetchReach-v1",
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper], "wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
"mp_kwargs": { "traj_gen_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,

View File

@ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping
import gym import gym
import numpy as np import numpy as np
from gym.envs.registration import EnvSpec from gym.envs.registration import EnvSpec
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
from mp_pytorch import MPInterface from mp_pytorch import MPInterface
from alr_envs.mp.basis_generator_factory import get_basis_generator from alr_envs.mp.basis_generator_factory import get_basis_generator
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
from alr_envs.mp.controllers.base_controller import BaseController from alr_envs.mp.controllers.base_controller import BaseController
from alr_envs.mp.controllers.controller_factory import get_controller from alr_envs.mp.controllers.controller_factory import get_controller
from alr_envs.mp.mp_factory import get_movement_primitive from alr_envs.mp.mp_factory import get_trajectory_generator
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
from alr_envs.mp.phase_generator_factory import get_phase_generator from alr_envs.mp.phase_generator_factory import get_phase_generator
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs): def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
@ -100,9 +98,8 @@ def make(env_id: str, seed, **kwargs):
def _make_wrapped_env( def _make_wrapped_env(
env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController, env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs
ep_wrapper_kwargs: Mapping, seed=1, **kwargs ):
):
""" """
Helper function for creating a wrapped gym environment using MPs. Helper function for creating a wrapped gym environment using MPs.
It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is
@ -118,73 +115,74 @@ def _make_wrapped_env(
""" """
# _env = gym.make(env_id) # _env = gym.make(env_id)
_env = make(env_id, seed, **kwargs) _env = make(env_id, seed, **kwargs)
has_episodic_wrapper = False has_black_box_wrapper = False
for w in wrappers: for w in wrappers:
# only wrap the environment if not EpisodicWrapper, e.g. for vision # only wrap the environment if not BlackBoxWrapper, e.g. for vision
if not issubclass(w, EpisodicWrapper): if issubclass(w, RawInterfaceWrapper):
_env = w(_env) has_black_box_wrapper = True
else: # if EpisodicWrapper, use specific constructor _env = w(_env)
has_episodic_wrapper = True if not has_black_box_wrapper:
_env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs) raise ValueError("An RawInterfaceWrapper is required in order to leverage movement primitive environments.")
if not has_episodic_wrapper:
raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.")
return _env return _env
def make_mp_from_kwargs( def make_bb_env(
env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping, env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1, controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1,
sequenced=False, **kwargs sequenced=False, **kwargs):
):
""" """
This can also be used standalone for manually building a custom DMP environment. This can also be used standalone for manually building a custom DMP environment.
Args: Args:
ep_wrapper_kwargs: black_box_wrapper_kwargs: kwargs for the black-box wrapper
basis_kwargs: basis_kwargs: kwargs for the basis generator
phase_kwargs: phase_kwargs: kwargs for the phase generator
controller_kwargs: controller_kwargs: kwargs for the tracking controller
env_id: base_env_name, env_id: base_env_name,
wrappers: list of wrappers (at least an EpisodicWrapper), wrappers: list of wrappers (at least an BlackBoxWrapper),
seed: seed of environment seed: seed of environment
sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory, sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory,
this behavior is much closer to step based learning. this behavior is much closer to step based learning.
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
Returns: DMP wrapped gym env Returns: DMP wrapped gym env
""" """
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
dummy_env = make(env_id, seed) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
if ep_wrapper_kwargs.get('duration', None) is None:
ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt if black_box_wrapper_kwargs.get('duration', None) is None:
black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt
if phase_kwargs.get('tau', None) is None: if phase_kwargs.get('tau', None) is None:
phase_kwargs['tau'] = ep_wrapper_kwargs['duration'] phase_kwargs['tau'] = black_box_wrapper_kwargs['duration']
mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item()) traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item())
phase_gen = get_phase_generator(**phase_kwargs) phase_gen = get_phase_generator(**phase_kwargs)
basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs) basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
controller = get_controller(**controller_kwargs) controller = get_controller(**controller_kwargs)
mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs) traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs)
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller,
ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs) bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller,
return _env **black_box_wrapper_kwargs)
return bb_env
def make_mp_env_helper(**kwargs): def make_bb_env_helper(**kwargs):
""" """
Helper function for registering a DMP gym environments. Helper function for registering a black box gym environment.
Args: Args:
**kwargs: expects at least the following: **kwargs: expects at least the following:
{ {
"name": base environment name. "name": base environment name.
"wrappers": list of wrappers (at least an EpisodicWrapper is required), "wrappers": list of wrappers (at least an BlackBoxWrapper is required),
"movement_primitives_kwargs": { "traj_gen_kwargs": {
"movement_primitives_type": type_of_your_movement_primitive, "trajectory_generator_type": type_of_your_movement_primitive,
non default arguments for the movement primitive instance non default arguments for the movement primitive instance
... ...
} }
"controller_kwargs": { "controller_kwargs": {
"controller_type": type_of_your_controller, "controller_type": type_of_your_controller,
non default arguments for the controller instance non default arguments for the tracking_controller instance
... ...
}, },
"basis_generator_kwargs": { "basis_generator_kwargs": {
@ -205,95 +203,17 @@ def make_mp_env_helper(**kwargs):
seed = kwargs.pop("seed", None) seed = kwargs.pop("seed", None)
wrappers = kwargs.pop("wrappers") wrappers = kwargs.pop("wrappers")
mp_kwargs = kwargs.pop("movement_primitives_kwargs") traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs') black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {})
contr_kwargs = kwargs.pop("controller_kwargs") contr_kwargs = kwargs.pop("controller_kwargs", {})
phase_kwargs = kwargs.pop("phase_generator_kwargs") phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
basis_kwargs = kwargs.pop("basis_generator_kwargs") basis_kwargs = kwargs.pop("basis_generator_kwargs", {})
return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs, return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers,
mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs, black_box_wrapper_kwargs=black_box_kwargs,
basis_kwargs=basis_kwargs, **kwargs, seed=seed) traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs,
phase_kwargs=phase_kwargs,
basis_kwargs=basis_kwargs, **kwargs, seed=seed)
def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
"""
This can also be used standalone for manually building a custom DMP environment.
Args:
env_id: base_env_name,
wrappers: list of wrappers (at least an MPEnvWrapper),
seed: seed of environment
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
Returns: DMP wrapped gym env
"""
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
_verify_dof(_env, mp_kwargs.get("num_dof"))
return DmpWrapper(_env, **mp_kwargs)
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
"""
This can also be used standalone for manually building a custom ProMP environment.
Args:
env_id: base_env_name,
wrappers: list of wrappers (at least an MPEnvWrapper),
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
Returns: ProMP wrapped gym env
"""
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
_verify_dof(_env, mp_kwargs.get("num_dof"))
return ProMPWrapper(_env, **mp_kwargs)
def make_dmp_env_helper(**kwargs):
"""
Helper function for registering a DMP gym environments.
Args:
**kwargs: expects at least the following:
{
"name": base_env_name,
"wrappers": list of wrappers (at least an MPEnvWrapper),
"mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
}
Returns: DMP wrapped gym env
"""
seed = kwargs.pop("seed", None)
return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def make_promp_env_helper(**kwargs):
"""
Helper function for registering ProMP gym environments.
This can also be used standalone for manually building a custom ProMP environment.
Args:
**kwargs: expects at least the following:
{
"name": base_env_name,
"wrappers": list of wrappers (at least an MPEnvWrapper),
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
}
Returns: ProMP wrapped gym env
"""
seed = kwargs.pop("seed", None)
return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
@ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[
It can be found in the BaseMP class. It can be found in the BaseMP class.
Args: Args:
mp_time_limit: max trajectory length of mp in seconds mp_time_limit: max trajectory length of trajectory_generator in seconds
env_time_limit: max trajectory length of DMC environment in seconds env_time_limit: max trajectory length of DMC environment in seconds
Returns: Returns: