restructuring
This commit is contained in:
parent
8fe6a83271
commit
02b8a65bab
@ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper]
|
|||||||
mp_kwargs = {...}
|
mp_kwargs = {...}
|
||||||
kwargs = {...}
|
kwargs = {...}
|
||||||
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
|
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
|
||||||
# OR for a deterministic ProMP (other mp_kwargs are required):
|
# OR for a deterministic ProMP (other traj_gen_kwargs are required):
|
||||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
|
||||||
|
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
@ -346,7 +346,7 @@ for _v in _versions:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:{_v}",
|
"name": f"alr_envs:{_v}",
|
||||||
"wrappers": [classic_control.simple_reacher.MPWrapper],
|
"wrappers": [classic_control.simple_reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -386,7 +386,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:ViaPointReacher-v0",
|
"name": "alr_envs:ViaPointReacher-v0",
|
||||||
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
|
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 5,
|
"num_dof": 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -424,7 +424,7 @@ for _v in _versions:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:HoleReacher-{_v}",
|
"name": f"alr_envs:HoleReacher-{_v}",
|
||||||
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 5,
|
"num_dof": 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -467,7 +467,7 @@ for _v in _versions:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:{_v}",
|
"name": f"alr_envs:{_v}",
|
||||||
"wrappers": [mujoco.reacher.MPWrapper],
|
"wrappers": [mujoco.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 5 if "long" not in _v.lower() else 7,
|
"num_dof": 5 if "long" not in _v.lower() else 7,
|
||||||
"num_basis": 2,
|
"num_basis": 2,
|
||||||
"duration": 4,
|
"duration": 4,
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
|
|
||||||
class NewMPWrapper(EpisodicWrapper):
|
class NewMPWrapper(RawInterfaceWrapper):
|
||||||
|
|
||||||
def set_active_obs(self):
|
def get_context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 111, # ant has 111 dimensional observation space !!
|
[False] * 111, # ant has 111 dimensional observation space !!
|
||||||
[True] # goal height
|
[True] # goal height
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
from typing import Tuple, Union
|
from typing import Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
|
|
||||||
class NewMPWrapper(EpisodicWrapper):
|
class NewMPWrapper(RawInterfaceWrapper):
|
||||||
|
|
||||||
# def __init__(self, replanning_model):
|
|
||||||
# self.replanning_model = replanning_model
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.env.sim.data.qpos[0:7].copy()
|
return self.env.sim.data.qpos[0:7].copy()
|
||||||
@ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper):
|
|||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.env.sim.data.qvel[0:7].copy()
|
return self.env.sim.data.qvel[0:7].copy()
|
||||||
|
|
||||||
def set_active_obs(self):
|
def get_context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 7, # cos
|
[False] * 7, # cos
|
||||||
[False] * 7, # sin
|
[False] * 7, # sin
|
||||||
@ -29,11 +25,6 @@ class NewMPWrapper(EpisodicWrapper):
|
|||||||
[False] # env steps
|
[False] # env steps
|
||||||
])
|
])
|
||||||
|
|
||||||
def do_replanning(self, pos, vel, s, a, t, last_replan_step):
|
|
||||||
return False
|
|
||||||
# const = np.arange(0, 1000, 10)
|
|
||||||
# return bool(self.replanning_model(s))
|
|
||||||
|
|
||||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||||
if self.mp.learn_tau:
|
if self.mp.learn_tau:
|
||||||
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class NewMPWrapper(EpisodicWrapper):
|
class NewMPWrapper(BlackBoxWrapper):
|
||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.env.sim.data.qpos[3:6].copy()
|
return self.env.sim.data.qpos[3:6].copy()
|
||||||
@ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper):
|
|||||||
# ])
|
# ])
|
||||||
|
|
||||||
# Random x goal + random init pos
|
# Random x goal + random init pos
|
||||||
def set_active_obs(self):
|
def get_context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
||||||
[True] * 3, # set to true if randomize initial pos
|
[True] * 3, # set to true if randomize initial pos
|
||||||
@ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper):
|
|||||||
|
|
||||||
|
|
||||||
class NewHighCtxtMPWrapper(NewMPWrapper):
|
class NewHighCtxtMPWrapper(NewMPWrapper):
|
||||||
def set_active_obs(self):
|
def get_context_mask(self):
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
||||||
[True] * 3, # set to true if randomize initial pos
|
[True] * 3, # set to true if randomize initial pos
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(EpisodicWrapper):
|
class MPWrapper(BlackBoxWrapper):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
@ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper):
|
|||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
return self.env.sim.data.qvel.flat[:self.env.n_links]
|
return self.env.sim.data.qvel.flat[:self.env.n_links]
|
||||||
|
|
||||||
def set_active_obs(self):
|
def get_context_mask(self):
|
||||||
return np.concatenate([
|
return np.concatenate([
|
||||||
[False] * self.env.n_links, # cos
|
[False] * self.env.n_links, # cos
|
||||||
[False] * self.env.n_links, # sin
|
[False] * self.env.n_links, # sin
|
||||||
|
@ -15,7 +15,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.ball_in_cup.MPWrapper],
|
"wrappers": [suite.ball_in_cup.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -41,7 +41,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.ball_in_cup.MPWrapper],
|
"wrappers": [suite.ball_in_cup.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -65,7 +65,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.reacher.MPWrapper],
|
"wrappers": [suite.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -92,7 +92,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.reacher.MPWrapper],
|
"wrappers": [suite.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -117,7 +117,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.reacher.MPWrapper],
|
"wrappers": [suite.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -144,7 +144,7 @@ register(
|
|||||||
"time_limit": 20,
|
"time_limit": 20,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.reacher.MPWrapper],
|
"wrappers": [suite.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 20,
|
||||||
@ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks:
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.MPWrapper],
|
"wrappers": [suite.cartpole.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks:
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.MPWrapper],
|
"wrappers": [suite.cartpole.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -230,7 +230,7 @@ register(
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -259,7 +259,7 @@ register(
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -286,7 +286,7 @@ register(
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -315,7 +315,7 @@ register(
|
|||||||
"camera_id": 0,
|
"camera_id": 0,
|
||||||
"episode_length": 1000,
|
"episode_length": 1000,
|
||||||
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -342,7 +342,7 @@ register(
|
|||||||
# "time_limit": 1,
|
# "time_limit": 1,
|
||||||
"episode_length": 250,
|
"episode_length": 250,
|
||||||
"wrappers": [manipulation.reach_site.MPWrapper],
|
"wrappers": [manipulation.reach_site.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 9,
|
"num_dof": 9,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
@ -365,7 +365,7 @@ register(
|
|||||||
# "time_limit": 1,
|
# "time_limit": 1,
|
||||||
"episode_length": 250,
|
"episode_length": 250,
|
||||||
"wrappers": [manipulation.reach_site.MPWrapper],
|
"wrappers": [manipulation.reach_site.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 9,
|
"num_dof": 9,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 10,
|
"duration": 10,
|
||||||
|
@ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
"learn_goal": True, # learn the goal position (recommended)
|
"learn_goal": True, # learn the goal position (recommended)
|
||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
"bandwidth_factor": 2,
|
"bandwidth_factor": 2,
|
||||||
"policy_type": "motor", # controller type, 'velocity', 'position', and 'motor' (torque control)
|
"policy_type": "motor", # tracking_controller type, 'velocity', 'position', and 'motor' (torque control)
|
||||||
"weights_scale": 1, # scaling of MP weights
|
"weights_scale": 1, # scaling of MP weights
|
||||||
"goal_scale": 1, # scaling of learned goal position
|
"goal_scale": 1, # scaling of learned goal position
|
||||||
"policy_kwargs": { # only required for torque control/PD-Controller
|
"policy_kwargs": { # only required for torque control/PD-Controller
|
||||||
@ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# "frame_skip": 1
|
# "frame_skip": 1
|
||||||
}
|
}
|
||||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||||
# OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
|
# OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples):
|
||||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
|
||||||
|
|
||||||
# This renders the full MP trajectory
|
# This renders the full MP trajectory
|
||||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||||
|
@ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
"width": 0.025, # width of the basis functions
|
"width": 0.025, # width of the basis functions
|
||||||
"zero_start": True, # start from current environment position if True
|
"zero_start": True, # start from current environment position if True
|
||||||
"weights_scale": 1, # scaling of MP weights
|
"weights_scale": 1, # scaling of MP weights
|
||||||
"policy_type": "metaworld", # custom controller type for metaworld environments
|
"policy_type": "metaworld", # custom tracking_controller type for metaworld environments
|
||||||
}
|
}
|
||||||
|
|
||||||
env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||||
# OR for a DMP (other mp_kwargs are required, see dmc_examples):
|
# OR for a DMP (other traj_gen_kwargs are required, see dmc_examples):
|
||||||
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs)
|
||||||
|
|
||||||
# This renders the full MP trajectory
|
# This renders the full MP trajectory
|
||||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||||
|
@ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Changing the mp_kwargs is possible by providing them to gym.
|
# Changing the traj_gen_kwargs is possible by providing them to gym.
|
||||||
# E.g. here by providing way to many basis functions
|
# E.g. here by providing way to many basis functions
|
||||||
mp_kwargs = {
|
mp_kwargs = {
|
||||||
"num_dof": 5,
|
"num_dof": 5,
|
||||||
@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
}
|
}
|
||||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||||
# OR for a deterministic ProMP:
|
# OR for a deterministic ProMP:
|
||||||
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs)
|
||||||
|
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render(mode="human")
|
||||||
|
@ -4,7 +4,7 @@ import alr_envs
|
|||||||
def example_mp(env_name, seed=1):
|
def example_mp(env_name, seed=1):
|
||||||
"""
|
"""
|
||||||
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
|
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
|
||||||
For more information on motion primitive specific stuff, look at the mp examples.
|
For more information on motion primitive specific stuff, look at the trajectory_generator examples.
|
||||||
Args:
|
Args:
|
||||||
env_name: ProMP env_id
|
env_name: ProMP env_id
|
||||||
seed: seed
|
seed: seed
|
||||||
|
@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env
|
|||||||
|
|
||||||
def visualize(env):
|
def visualize(env):
|
||||||
t = env.t
|
t = env.t
|
||||||
pos_features = env.mp.basis_generator.basis(t)
|
pos_features = env.trajectory_generator.basis_generator.basis(t)
|
||||||
plt.plot(t, pos_features)
|
plt.plot(t, pos_features)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ for _task in _goal_change_envs:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [goal_change_mp_wrapper.MPWrapper],
|
"wrappers": [goal_change_mp_wrapper.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
@ -42,7 +42,7 @@ for _task in _object_change_envs:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [object_change_mp_wrapper.MPWrapper],
|
"wrappers": [object_change_mp_wrapper.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
@ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [goal_object_change_mp_wrapper.MPWrapper],
|
"wrappers": [goal_object_change_mp_wrapper.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
@ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs:
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": _task,
|
"name": _task,
|
||||||
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
|
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 6.25,
|
"duration": 6.25,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from typing import Union, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -7,77 +7,77 @@ from gym import spaces
|
|||||||
from mp_pytorch.mp.mp_interfaces import MPInterface
|
from mp_pytorch.mp.mp_interfaces import MPInterface
|
||||||
|
|
||||||
from alr_envs.mp.controllers.base_controller import BaseController
|
from alr_envs.mp.controllers.base_controller import BaseController
|
||||||
|
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
|
|
||||||
class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
class BlackBoxWrapper(gym.ObservationWrapper, ABC):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
env: RawInterfaceWrapper,
|
||||||
|
trajectory_generator: MPInterface, tracking_controller: BaseController,
|
||||||
|
duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum):
|
||||||
"""
|
"""
|
||||||
Base class for movement primitive based gym.Wrapper implementations.
|
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The (wrapped) environment this wrapper is applied on
|
env: The (wrapped) environment this wrapper is applied on
|
||||||
num_dof: Dimension of the action space of the wrapped env
|
trajectory_generator: Generates the full or partial trajectory
|
||||||
num_basis: Number of basis functions per dof
|
tracking_controller: Translates the desired trajectory to raw action sequences
|
||||||
duration: Length of the trajectory of the movement primitive in seconds
|
duration: Length of the trajectory of the movement primitive in seconds
|
||||||
controller: Type or object defining the policy that is used to generate action based on the trajectory
|
verbose: level of detail for returned values in info dict.
|
||||||
weight_scale: Scaling parameter for the actions given to this wrapper
|
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
|
||||||
render_mode: Equivalent to gym render mode
|
reward, default summation over all values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
env: gym.Env,
|
|
||||||
mp: MPInterface,
|
|
||||||
controller: BaseController,
|
|
||||||
duration: float,
|
|
||||||
render_mode: str = None,
|
|
||||||
verbose: int = 1,
|
|
||||||
weight_scale: float = 1,
|
|
||||||
sequencing=True,
|
|
||||||
reward_aggregation=np.mean,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.env = env
|
self.env = env
|
||||||
try:
|
|
||||||
self.dt = env.dt
|
|
||||||
except AttributeError:
|
|
||||||
raise AttributeError("step based environment needs to have a function 'dt' ")
|
|
||||||
self.duration = duration
|
self.duration = duration
|
||||||
self.traj_steps = int(duration / self.dt)
|
self.traj_steps = int(duration / self.dt)
|
||||||
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
|
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
|
||||||
# duration = self.env.max_episode_steps * self.dt
|
# duration = self.env.max_episode_steps * self.dt
|
||||||
|
|
||||||
self.mp = mp
|
# trajectory generation
|
||||||
self.env = env
|
self.trajectory_generator = trajectory_generator
|
||||||
self.controller = controller
|
self.tracking_controller = tracking_controller
|
||||||
self.weight_scale = weight_scale
|
# self.weight_scale = weight_scale
|
||||||
|
|
||||||
# rendering
|
|
||||||
self.render_mode = render_mode
|
|
||||||
self.render_kwargs = {}
|
|
||||||
self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||||
self.mp.set_mp_times(self.time_steps)
|
self.trajectory_generator.set_mp_times(self.time_steps)
|
||||||
# self.mp.set_mp_duration(self.time_steps, dt)
|
# self.trajectory_generator.set_mp_duration(self.time_steps, dt)
|
||||||
# action_bounds = np.inf * np.ones((np.prod(self.mp.num_params)))
|
# action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
|
||||||
self.mp_action_space = self.get_mp_action_space()
|
self.reward_aggregation = reward_aggregation
|
||||||
|
|
||||||
|
# spaces
|
||||||
|
self.mp_action_space = self.get_mp_action_space()
|
||||||
self.action_space = self.get_action_space()
|
self.action_space = self.get_action_space()
|
||||||
self.active_obs = self.set_active_obs()
|
self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask],
|
||||||
self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs],
|
high=self.env.observation_space.high[self.env.context_mask],
|
||||||
high=self.env.observation_space.high[self.active_obs],
|
|
||||||
dtype=self.env.observation_space.dtype)
|
dtype=self.env.observation_space.dtype)
|
||||||
|
|
||||||
|
# rendering
|
||||||
|
self.render_mode = None
|
||||||
|
self.render_kwargs = {}
|
||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self):
|
||||||
|
return self.env.dt
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
return observation[self.env.context_mask]
|
||||||
|
|
||||||
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
||||||
# TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
|
# TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
|
||||||
# the beginning of the array.
|
# the beginning of the array.
|
||||||
ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay)
|
# ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
|
||||||
scaled_mp_params = action.copy()
|
# scaled_mp_params = action.copy()
|
||||||
scaled_mp_params[ignore_indices:] *= self.weight_scale
|
# scaled_mp_params[ignore_indices:] *= self.weight_scale
|
||||||
self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high))
|
|
||||||
self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel)
|
clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
|
||||||
traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True)
|
self.trajectory_generator.set_params(clipped_params)
|
||||||
|
self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
|
||||||
|
bc_vel=self.current_vel)
|
||||||
|
traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
|
||||||
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
||||||
|
|
||||||
trajectory = trajectory_tensor.numpy()
|
trajectory = trajectory_tensor.numpy()
|
||||||
@ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
|||||||
# TODO: Do we need this or does mp_pytorch have this?
|
# TODO: Do we need this or does mp_pytorch have this?
|
||||||
if self.post_traj_steps > 0:
|
if self.post_traj_steps > 0:
|
||||||
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
||||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))])
|
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))])
|
||||||
|
|
||||||
return trajectory, velocity
|
return trajectory, velocity
|
||||||
|
|
||||||
def get_mp_action_space(self):
|
def get_mp_action_space(self):
|
||||||
"""This function can be used to set up an individual space for the parameters of the mp."""
|
"""This function can be used to set up an individual space for the parameters of the trajectory_generator."""
|
||||||
min_action_bounds, max_action_bounds = self.mp.get_param_bounds()
|
min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds()
|
||||||
mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
|
mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
|
||||||
dtype=np.float32)
|
dtype=np.float32)
|
||||||
return mp_action_space
|
return mp_action_space
|
||||||
@ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
return self.get_mp_action_space()
|
return self.get_mp_action_space()
|
||||||
|
|
||||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
|
||||||
"""
|
|
||||||
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
|
||||||
include other actions like ball releasing time for the beer pong environment.
|
|
||||||
This only needs to be overwritten if the action space is modified.
|
|
||||||
Args:
|
|
||||||
action: a vector instance of the whole action space, includes mp parameters and additional parameters if
|
|
||||||
specified, else only mp parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple: mp_arguments and other arguments
|
|
||||||
"""
|
|
||||||
return action, None
|
|
||||||
|
|
||||||
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
|
|
||||||
np.ndarray]:
|
|
||||||
"""
|
|
||||||
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
|
|
||||||
Beerpong env. The parameters used should not be part of the motion primitive parameters.
|
|
||||||
Returns step_action by default, can be overwritten in individual mp_wrappers.
|
|
||||||
Args:
|
|
||||||
t: the current time step of the episode
|
|
||||||
env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback
|
|
||||||
(e.g. ball release time in Beer Pong)
|
|
||||||
step_action: the current step-based action
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
modified step action
|
|
||||||
"""
|
|
||||||
return step_action
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set_active_obs(self) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
This function defines the contexts. The contexts are defined as specific observations.
|
|
||||||
Returns:
|
|
||||||
boolearn array representing the indices of the observations
|
|
||||||
|
|
||||||
"""
|
|
||||||
return np.ones(self.env.observation_space.shape[0], dtype=bool)
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
"""
|
|
||||||
Returns the current position of the action/control dimension.
|
|
||||||
The dimensionality has to match the action/control dimension.
|
|
||||||
This is not required when exclusively using velocity control,
|
|
||||||
it should, however, be implemented regardless.
|
|
||||||
E.g. The joint positions that are directly or indirectly controlled by the action.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
|
||||||
"""
|
|
||||||
Returns the current velocity of the action/control dimension.
|
|
||||||
The dimensionality has to match the action/control dimension.
|
|
||||||
This is not required when exclusively using position control,
|
|
||||||
it should, however, be implemented regardless.
|
|
||||||
E.g. The joint velocities that are directly or indirectly controlled by the action.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||||
# TODO: Think about sequencing
|
# TODO: Think about sequencing
|
||||||
@ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
|||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# self.time_steps = np.linspace(0, learned_duration, self.traj_steps)
|
# self.time_steps = np.linspace(0, learned_duration, self.traj_steps)
|
||||||
# self.mp.set_mp_times(self.time_steps)
|
# self.trajectory_generator.set_mp_times(self.time_steps)
|
||||||
|
|
||||||
trajectory_length = len(trajectory)
|
trajectory_length = len(trajectory)
|
||||||
|
rewards = np.zeros(shape=(trajectory_length,))
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
|
actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
|
||||||
observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
|
observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
|
||||||
dtype=self.env.observation_space.dtype)
|
dtype=self.env.observation_space.dtype)
|
||||||
rewards = np.zeros(shape=(trajectory_length,))
|
|
||||||
trajectory_return = 0
|
|
||||||
|
|
||||||
infos = dict()
|
infos = dict()
|
||||||
|
done = False
|
||||||
|
|
||||||
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
||||||
step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel)
|
step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos,
|
||||||
|
self.current_vel)
|
||||||
step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info
|
step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info
|
||||||
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
||||||
# print('step/clipped action ratio: ', step_action/c_action)
|
# print('step/clipped action ratio: ', step_action/c_action)
|
||||||
obs, c_reward, done, info = self.env.step(c_action)
|
obs, c_reward, done, info = self.env.step(c_action)
|
||||||
|
rewards[t] = c_reward
|
||||||
|
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
actions[t, :] = c_action
|
actions[t, :] = c_action
|
||||||
rewards[t] = c_reward
|
|
||||||
observations[t, :] = obs
|
observations[t, :] = obs
|
||||||
trajectory_return += c_reward
|
|
||||||
for k, v in info.items():
|
for k, v in info.items():
|
||||||
elems = infos.get(k, [None] * trajectory_length)
|
elems = infos.get(k, [None] * trajectory_length)
|
||||||
elems[t] = v
|
elems[t] = v
|
||||||
infos[k] = elems
|
infos[k] = elems
|
||||||
# infos['step_infos'].append(info)
|
|
||||||
if self.render_mode:
|
if self.render_mode is not None:
|
||||||
self.render(mode=self.render_mode, **self.render_kwargs)
|
self.render(mode=self.render_mode, **self.render_kwargs)
|
||||||
if done or do_replanning(kwargs):
|
|
||||||
|
if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
|
||||||
break
|
break
|
||||||
|
|
||||||
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
||||||
|
|
||||||
if self.verbose >= 2:
|
if self.verbose >= 2:
|
||||||
infos['trajectory'] = trajectory
|
infos['trajectory'] = trajectory
|
||||||
infos['step_actions'] = actions[:t + 1]
|
infos['step_actions'] = actions[:t + 1]
|
||||||
infos['step_observations'] = observations[:t + 1]
|
infos['step_observations'] = observations[:t + 1]
|
||||||
infos['step_rewards'] = rewards[:t + 1]
|
infos['step_rewards'] = rewards[:t + 1]
|
||||||
|
|
||||||
infos['trajectory_length'] = t + 1
|
infos['trajectory_length'] = t + 1
|
||||||
done = True
|
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||||
return self.get_observation_from_step(obs), trajectory_return, done, infos
|
return self.get_observation_from_step(obs), trajectory_return, done, infos
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
@ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
|||||||
class MetaWorldController(BaseController):
|
class MetaWorldController(BaseController):
|
||||||
"""
|
"""
|
||||||
A Metaworld Controller. Using position and velocity information from a provided environment,
|
A Metaworld Controller. Using position and velocity information from a provided environment,
|
||||||
the controller calculates a response based on the desired position and velocity.
|
the tracking_controller calculates a response based on the desired position and velocity.
|
||||||
Unlike the other Controllers, this is a special controller for MetaWorld environments.
|
Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
|
||||||
They use a position delta for the xyz coordinates and a raw position for the gripper opening.
|
They use a position delta for the xyz coordinates and a raw position for the gripper opening.
|
||||||
|
|
||||||
:param env: A position environment
|
:param env: A position environment
|
||||||
|
@ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
|||||||
class PDController(BaseController):
|
class PDController(BaseController):
|
||||||
"""
|
"""
|
||||||
A PD-Controller. Using position and velocity information from a provided environment,
|
A PD-Controller. Using position and velocity information from a provided environment,
|
||||||
the controller calculates a response based on the desired position and velocity
|
the tracking_controller calculates a response based on the desired position and velocity
|
||||||
|
|
||||||
:param env: A position environment
|
:param env: A position environment
|
||||||
:param p_gains: Factors for the proportional gains
|
:param p_gains: Factors for the proportional gains
|
||||||
|
@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
|||||||
|
|
||||||
class PosController(BaseController):
|
class PosController(BaseController):
|
||||||
"""
|
"""
|
||||||
A Position Controller. The controller calculates a response only based on the desired position.
|
A Position Controller. The tracking_controller calculates a response only based on the desired position.
|
||||||
"""
|
"""
|
||||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||||
return des_pos
|
return des_pos
|
||||||
|
@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
|||||||
|
|
||||||
class VelController(BaseController):
|
class VelController(BaseController):
|
||||||
"""
|
"""
|
||||||
A Velocity Controller. The controller calculates a response only based on the desired velocity.
|
A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity.
|
||||||
"""
|
"""
|
||||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||||
return des_vel
|
return des_vel
|
||||||
|
@ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
|||||||
ALL_TYPES = ["promp", "dmp", "idmp"]
|
ALL_TYPES = ["promp", "dmp", "idmp"]
|
||||||
|
|
||||||
|
|
||||||
def get_movement_primitive(
|
def get_trajectory_generator(
|
||||||
movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
|
trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
|
||||||
):
|
):
|
||||||
movement_primitives_type = movement_primitives_type.lower()
|
trajectory_generator_type = trajectory_generator_type.lower()
|
||||||
if movement_primitives_type == "promp":
|
if trajectory_generator_type == "promp":
|
||||||
return ProMP(basis_generator, action_dim, **kwargs)
|
return ProMP(basis_generator, action_dim, **kwargs)
|
||||||
elif movement_primitives_type == "dmp":
|
elif trajectory_generator_type == "dmp":
|
||||||
return DMP(basis_generator, action_dim, **kwargs)
|
return DMP(basis_generator, action_dim, **kwargs)
|
||||||
elif movement_primitives_type == 'idmp':
|
elif trajectory_generator_type == 'idmp':
|
||||||
return IDMP(basis_generator, action_dim, **kwargs)
|
return IDMP(basis_generator, action_dim, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, "
|
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
|
||||||
f"please choose one of {ALL_TYPES}.")
|
f"please choose one of {ALL_TYPES}.")
|
88
alr_envs/mp/raw_interface_wrapper.py
Normal file
88
alr_envs/mp/raw_interface_wrapper.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import Union, Tuple
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class RawInterfaceWrapper(gym.Wrapper):
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def context_mask(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
This function defines the contexts. The contexts are defined as specific observations.
|
||||||
|
Returns:
|
||||||
|
bool array representing the indices of the observations
|
||||||
|
|
||||||
|
"""
|
||||||
|
return np.ones(self.env.observation_space.shape[0], dtype=bool)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
"""
|
||||||
|
Returns the current position of the action/control dimension.
|
||||||
|
The dimensionality has to match the action/control dimension.
|
||||||
|
This is not required when exclusively using velocity control,
|
||||||
|
it should, however, be implemented regardless.
|
||||||
|
E.g. The joint positions that are directly or indirectly controlled by the action.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||||
|
"""
|
||||||
|
Returns the current velocity of the action/control dimension.
|
||||||
|
The dimensionality has to match the action/control dimension.
|
||||||
|
This is not required when exclusively using position control,
|
||||||
|
it should, however, be implemented regardless.
|
||||||
|
E.g. The joint velocities that are directly or indirectly controlled by the action.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def dt(self) -> float:
|
||||||
|
"""
|
||||||
|
Control frequency of the environment
|
||||||
|
Returns: float
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def do_replanning(self, pos, vel, s, a, t):
|
||||||
|
# return t % 100 == 0
|
||||||
|
# return bool(self.replanning_model(s))
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||||
|
"""
|
||||||
|
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
||||||
|
include other actions like ball releasing time for the beer pong environment.
|
||||||
|
This only needs to be overwritten if the action space is modified.
|
||||||
|
Args:
|
||||||
|
action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if
|
||||||
|
specified, else only trajectory_generator parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple: mp_arguments and other arguments
|
||||||
|
"""
|
||||||
|
return action, None
|
||||||
|
|
||||||
|
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
|
||||||
|
np.ndarray]:
|
||||||
|
"""
|
||||||
|
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
|
||||||
|
Beerpong env. The parameters used should not be part of the motion primitive parameters.
|
||||||
|
Returns step_action by default, can be overwritten in individual mp_wrappers.
|
||||||
|
Args:
|
||||||
|
t: the current time step of the episode
|
||||||
|
env_spec_params: the environment specific parameter, as defined in function _episode_callback
|
||||||
|
(e.g. ball release time in Beer Pong)
|
||||||
|
step_action: the current step-based action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
modified step action
|
||||||
|
"""
|
||||||
|
return step_action
|
@ -21,7 +21,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:MountainCarContinuous-v1",
|
"name": "alr_envs:MountainCarContinuous-v1",
|
||||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 4,
|
"num_basis": 4,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -43,7 +43,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
||||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 4,
|
"num_basis": 4,
|
||||||
"duration": 19.98,
|
"duration": 19.98,
|
||||||
@ -65,7 +65,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.mujoco:Reacher-v2",
|
"name": "gym.envs.mujoco:Reacher-v2",
|
||||||
"wrappers": [mujoco.reacher_v2.MPWrapper],
|
"wrappers": [mujoco.reacher_v2.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 6,
|
"num_basis": 6,
|
||||||
"duration": 1,
|
"duration": 1,
|
||||||
@ -87,7 +87,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
||||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -105,7 +105,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchSlide-v1",
|
"name": "gym.envs.robotics:FetchSlide-v1",
|
||||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -123,7 +123,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchReachDense-v1",
|
"name": "gym.envs.robotics:FetchReachDense-v1",
|
||||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -141,7 +141,7 @@ register(
|
|||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchReach-v1",
|
"name": "gym.envs.robotics:FetchReach-v1",
|
||||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
|
@ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.envs.registration import EnvSpec
|
from gym.envs.registration import EnvSpec
|
||||||
|
|
||||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
|
||||||
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
|
|
||||||
from mp_pytorch import MPInterface
|
from mp_pytorch import MPInterface
|
||||||
|
|
||||||
from alr_envs.mp.basis_generator_factory import get_basis_generator
|
from alr_envs.mp.basis_generator_factory import get_basis_generator
|
||||||
|
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||||
from alr_envs.mp.controllers.base_controller import BaseController
|
from alr_envs.mp.controllers.base_controller import BaseController
|
||||||
from alr_envs.mp.controllers.controller_factory import get_controller
|
from alr_envs.mp.controllers.controller_factory import get_controller
|
||||||
from alr_envs.mp.mp_factory import get_movement_primitive
|
from alr_envs.mp.mp_factory import get_trajectory_generator
|
||||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
|
||||||
from alr_envs.mp.phase_generator_factory import get_phase_generator
|
from alr_envs.mp.phase_generator_factory import get_phase_generator
|
||||||
|
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
|
|
||||||
|
|
||||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||||
@ -100,9 +98,8 @@ def make(env_id: str, seed, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def _make_wrapped_env(
|
def _make_wrapped_env(
|
||||||
env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController,
|
env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs
|
||||||
ep_wrapper_kwargs: Mapping, seed=1, **kwargs
|
):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Helper function for creating a wrapped gym environment using MPs.
|
Helper function for creating a wrapped gym environment using MPs.
|
||||||
It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is
|
It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is
|
||||||
@ -118,73 +115,74 @@ def _make_wrapped_env(
|
|||||||
"""
|
"""
|
||||||
# _env = gym.make(env_id)
|
# _env = gym.make(env_id)
|
||||||
_env = make(env_id, seed, **kwargs)
|
_env = make(env_id, seed, **kwargs)
|
||||||
has_episodic_wrapper = False
|
has_black_box_wrapper = False
|
||||||
for w in wrappers:
|
for w in wrappers:
|
||||||
# only wrap the environment if not EpisodicWrapper, e.g. for vision
|
# only wrap the environment if not BlackBoxWrapper, e.g. for vision
|
||||||
if not issubclass(w, EpisodicWrapper):
|
if issubclass(w, RawInterfaceWrapper):
|
||||||
|
has_black_box_wrapper = True
|
||||||
_env = w(_env)
|
_env = w(_env)
|
||||||
else: # if EpisodicWrapper, use specific constructor
|
if not has_black_box_wrapper:
|
||||||
has_episodic_wrapper = True
|
raise ValueError("An RawInterfaceWrapper is required in order to leverage movement primitive environments.")
|
||||||
_env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs)
|
|
||||||
if not has_episodic_wrapper:
|
|
||||||
raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.")
|
|
||||||
return _env
|
return _env
|
||||||
|
|
||||||
|
|
||||||
def make_mp_from_kwargs(
|
def make_bb_env(
|
||||||
env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping,
|
env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
||||||
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1,
|
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1,
|
||||||
sequenced=False, **kwargs
|
sequenced=False, **kwargs):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
This can also be used standalone for manually building a custom DMP environment.
|
This can also be used standalone for manually building a custom DMP environment.
|
||||||
Args:
|
Args:
|
||||||
ep_wrapper_kwargs:
|
black_box_wrapper_kwargs: kwargs for the black-box wrapper
|
||||||
basis_kwargs:
|
basis_kwargs: kwargs for the basis generator
|
||||||
phase_kwargs:
|
phase_kwargs: kwargs for the phase generator
|
||||||
controller_kwargs:
|
controller_kwargs: kwargs for the tracking controller
|
||||||
env_id: base_env_name,
|
env_id: base_env_name,
|
||||||
wrappers: list of wrappers (at least an EpisodicWrapper),
|
wrappers: list of wrappers (at least an BlackBoxWrapper),
|
||||||
seed: seed of environment
|
seed: seed of environment
|
||||||
sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory,
|
sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory,
|
||||||
this behavior is much closer to step based learning.
|
this behavior is much closer to step based learning.
|
||||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
||||||
|
|
||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
_verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||||
dummy_env = make(env_id, seed)
|
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||||
if ep_wrapper_kwargs.get('duration', None) is None:
|
|
||||||
ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt
|
if black_box_wrapper_kwargs.get('duration', None) is None:
|
||||||
|
black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt
|
||||||
if phase_kwargs.get('tau', None) is None:
|
if phase_kwargs.get('tau', None) is None:
|
||||||
phase_kwargs['tau'] = ep_wrapper_kwargs['duration']
|
phase_kwargs['tau'] = black_box_wrapper_kwargs['duration']
|
||||||
mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item())
|
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item())
|
||||||
|
|
||||||
phase_gen = get_phase_generator(**phase_kwargs)
|
phase_gen = get_phase_generator(**phase_kwargs)
|
||||||
basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
|
basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
|
||||||
controller = get_controller(**controller_kwargs)
|
controller = get_controller(**controller_kwargs)
|
||||||
mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs)
|
traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs)
|
||||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller,
|
|
||||||
ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs)
|
bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller,
|
||||||
return _env
|
**black_box_wrapper_kwargs)
|
||||||
|
|
||||||
|
return bb_env
|
||||||
|
|
||||||
|
|
||||||
def make_mp_env_helper(**kwargs):
|
def make_bb_env_helper(**kwargs):
|
||||||
"""
|
"""
|
||||||
Helper function for registering a DMP gym environments.
|
Helper function for registering a black box gym environment.
|
||||||
Args:
|
Args:
|
||||||
**kwargs: expects at least the following:
|
**kwargs: expects at least the following:
|
||||||
{
|
{
|
||||||
"name": base environment name.
|
"name": base environment name.
|
||||||
"wrappers": list of wrappers (at least an EpisodicWrapper is required),
|
"wrappers": list of wrappers (at least an BlackBoxWrapper is required),
|
||||||
"movement_primitives_kwargs": {
|
"traj_gen_kwargs": {
|
||||||
"movement_primitives_type": type_of_your_movement_primitive,
|
"trajectory_generator_type": type_of_your_movement_primitive,
|
||||||
non default arguments for the movement primitive instance
|
non default arguments for the movement primitive instance
|
||||||
...
|
...
|
||||||
}
|
}
|
||||||
"controller_kwargs": {
|
"controller_kwargs": {
|
||||||
"controller_type": type_of_your_controller,
|
"controller_type": type_of_your_controller,
|
||||||
non default arguments for the controller instance
|
non default arguments for the tracking_controller instance
|
||||||
...
|
...
|
||||||
},
|
},
|
||||||
"basis_generator_kwargs": {
|
"basis_generator_kwargs": {
|
||||||
@ -205,97 +203,19 @@ def make_mp_env_helper(**kwargs):
|
|||||||
seed = kwargs.pop("seed", None)
|
seed = kwargs.pop("seed", None)
|
||||||
wrappers = kwargs.pop("wrappers")
|
wrappers = kwargs.pop("wrappers")
|
||||||
|
|
||||||
mp_kwargs = kwargs.pop("movement_primitives_kwargs")
|
traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
|
||||||
ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs')
|
black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {})
|
||||||
contr_kwargs = kwargs.pop("controller_kwargs")
|
contr_kwargs = kwargs.pop("controller_kwargs", {})
|
||||||
phase_kwargs = kwargs.pop("phase_generator_kwargs")
|
phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
|
||||||
basis_kwargs = kwargs.pop("basis_generator_kwargs")
|
basis_kwargs = kwargs.pop("basis_generator_kwargs", {})
|
||||||
|
|
||||||
return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs,
|
return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers,
|
||||||
mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs,
|
black_box_wrapper_kwargs=black_box_kwargs,
|
||||||
|
traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs,
|
||||||
|
phase_kwargs=phase_kwargs,
|
||||||
basis_kwargs=basis_kwargs, **kwargs, seed=seed)
|
basis_kwargs=basis_kwargs, **kwargs, seed=seed)
|
||||||
|
|
||||||
|
|
||||||
def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
|
||||||
"""
|
|
||||||
This can also be used standalone for manually building a custom DMP environment.
|
|
||||||
Args:
|
|
||||||
env_id: base_env_name,
|
|
||||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
|
||||||
seed: seed of environment
|
|
||||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
|
||||||
|
|
||||||
Returns: DMP wrapped gym env
|
|
||||||
|
|
||||||
"""
|
|
||||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
|
||||||
|
|
||||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
|
||||||
|
|
||||||
_verify_dof(_env, mp_kwargs.get("num_dof"))
|
|
||||||
|
|
||||||
return DmpWrapper(_env, **mp_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
|
||||||
"""
|
|
||||||
This can also be used standalone for manually building a custom ProMP environment.
|
|
||||||
Args:
|
|
||||||
env_id: base_env_name,
|
|
||||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
|
||||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
|
||||||
|
|
||||||
Returns: ProMP wrapped gym env
|
|
||||||
|
|
||||||
"""
|
|
||||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
|
||||||
|
|
||||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
|
||||||
|
|
||||||
_verify_dof(_env, mp_kwargs.get("num_dof"))
|
|
||||||
|
|
||||||
return ProMPWrapper(_env, **mp_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def make_dmp_env_helper(**kwargs):
|
|
||||||
"""
|
|
||||||
Helper function for registering a DMP gym environments.
|
|
||||||
Args:
|
|
||||||
**kwargs: expects at least the following:
|
|
||||||
{
|
|
||||||
"name": base_env_name,
|
|
||||||
"wrappers": list of wrappers (at least an MPEnvWrapper),
|
|
||||||
"mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
|
|
||||||
}
|
|
||||||
|
|
||||||
Returns: DMP wrapped gym env
|
|
||||||
|
|
||||||
"""
|
|
||||||
seed = kwargs.pop("seed", None)
|
|
||||||
return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
|
||||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def make_promp_env_helper(**kwargs):
|
|
||||||
"""
|
|
||||||
Helper function for registering ProMP gym environments.
|
|
||||||
This can also be used standalone for manually building a custom ProMP environment.
|
|
||||||
Args:
|
|
||||||
**kwargs: expects at least the following:
|
|
||||||
{
|
|
||||||
"name": base_env_name,
|
|
||||||
"wrappers": list of wrappers (at least an MPEnvWrapper),
|
|
||||||
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
|
|
||||||
}
|
|
||||||
|
|
||||||
Returns: ProMP wrapped gym env
|
|
||||||
|
|
||||||
"""
|
|
||||||
seed = kwargs.pop("seed", None)
|
|
||||||
return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
|
||||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
||||||
"""
|
"""
|
||||||
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
|
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
|
||||||
@ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[
|
|||||||
It can be found in the BaseMP class.
|
It can be found in the BaseMP class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mp_time_limit: max trajectory length of mp in seconds
|
mp_time_limit: max trajectory length of trajectory_generator in seconds
|
||||||
env_time_limit: max trajectory length of DMC environment in seconds
|
env_time_limit: max trajectory length of DMC environment in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
Loading…
Reference in New Issue
Block a user