restructuring
This commit is contained in:
parent
8fe6a83271
commit
02b8a65bab
@ -198,8 +198,8 @@ wrappers = [alr_envs.dmc.suite.ball_in_cup.MPWrapper]
|
||||
mp_kwargs = {...}
|
||||
kwargs = {...}
|
||||
env = alr_envs.make_dmp_env(base_env_id, wrappers=wrappers, seed=1, mp_kwargs=mp_kwargs, **kwargs)
|
||||
# OR for a deterministic ProMP (other mp_kwargs are required):
|
||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||
# OR for a deterministic ProMP (other traj_gen_kwargs are required):
|
||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
|
||||
|
||||
rewards = 0
|
||||
obs = env.reset()
|
||||
|
@ -346,7 +346,7 @@ for _v in _versions:
|
||||
kwargs={
|
||||
"name": f"alr_envs:{_v}",
|
||||
"wrappers": [classic_control.simple_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -386,7 +386,7 @@ register(
|
||||
kwargs={
|
||||
"name": "alr_envs:ViaPointReacher-v0",
|
||||
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -424,7 +424,7 @@ for _v in _versions:
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{_v}",
|
||||
"wrappers": [classic_control.hole_reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -467,7 +467,7 @@ for _v in _versions:
|
||||
kwargs={
|
||||
"name": f"alr_envs:{_v}",
|
||||
"wrappers": [mujoco.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 5 if "long" not in _v.lower() else 7,
|
||||
"num_basis": 2,
|
||||
"duration": 4,
|
||||
|
@ -1,12 +1,13 @@
|
||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||
from typing import Union, Tuple
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class NewMPWrapper(EpisodicWrapper):
|
||||
class NewMPWrapper(RawInterfaceWrapper):
|
||||
|
||||
def set_active_obs(self):
|
||||
def get_context_mask(self):
|
||||
return np.hstack([
|
||||
[False] * 111, # ant has 111 dimensional observation space !!
|
||||
[True] # goal height
|
||||
|
@ -1,15 +1,11 @@
|
||||
from typing import Tuple, Union
|
||||
from typing import Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class NewMPWrapper(EpisodicWrapper):
|
||||
|
||||
# def __init__(self, replanning_model):
|
||||
# self.replanning_model = replanning_model
|
||||
|
||||
class NewMPWrapper(RawInterfaceWrapper):
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.sim.data.qpos[0:7].copy()
|
||||
@ -18,7 +14,7 @@ class NewMPWrapper(EpisodicWrapper):
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.sim.data.qvel[0:7].copy()
|
||||
|
||||
def set_active_obs(self):
|
||||
def get_context_mask(self):
|
||||
return np.hstack([
|
||||
[False] * 7, # cos
|
||||
[False] * 7, # sin
|
||||
@ -29,11 +25,6 @@ class NewMPWrapper(EpisodicWrapper):
|
||||
[False] # env steps
|
||||
])
|
||||
|
||||
def do_replanning(self, pos, vel, s, a, t, last_replan_step):
|
||||
return False
|
||||
# const = np.arange(0, 1000, 10)
|
||||
# return bool(self.replanning_model(s))
|
||||
|
||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
if self.mp.learn_tau:
|
||||
self.env.env.release_step = action[0] / self.env.dt # Tau value
|
||||
|
@ -1,9 +1,9 @@
|
||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||
from typing import Union, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NewMPWrapper(EpisodicWrapper):
|
||||
class NewMPWrapper(BlackBoxWrapper):
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.sim.data.qpos[3:6].copy()
|
||||
@ -21,7 +21,7 @@ class NewMPWrapper(EpisodicWrapper):
|
||||
# ])
|
||||
|
||||
# Random x goal + random init pos
|
||||
def set_active_obs(self):
|
||||
def get_context_mask(self):
|
||||
return np.hstack([
|
||||
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
||||
[True] * 3, # set to true if randomize initial pos
|
||||
@ -31,7 +31,7 @@ class NewMPWrapper(EpisodicWrapper):
|
||||
|
||||
|
||||
class NewHighCtxtMPWrapper(NewMPWrapper):
|
||||
def set_active_obs(self):
|
||||
def get_context_mask(self):
|
||||
return np.hstack([
|
||||
[False] * (2 + int(not self.env.exclude_current_positions_from_observation)), # position
|
||||
[True] * 3, # set to true if randomize initial pos
|
||||
|
@ -1,9 +1,9 @@
|
||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||
from typing import Union, Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MPWrapper(EpisodicWrapper):
|
||||
class MPWrapper(BlackBoxWrapper):
|
||||
|
||||
@property
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
@ -12,7 +12,7 @@ class MPWrapper(EpisodicWrapper):
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.env.sim.data.qvel.flat[:self.env.n_links]
|
||||
|
||||
def set_active_obs(self):
|
||||
def get_context_mask(self):
|
||||
return np.concatenate([
|
||||
[False] * self.env.n_links, # cos
|
||||
[False] * self.env.n_links, # sin
|
||||
|
@ -15,7 +15,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.ball_in_cup.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -41,7 +41,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.ball_in_cup.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -65,7 +65,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -92,7 +92,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -117,7 +117,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -144,7 +144,7 @@ register(
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.reacher.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
@ -174,7 +174,7 @@ for _task in _dmc_cartpole_tasks:
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -203,7 +203,7 @@ for _task in _dmc_cartpole_tasks:
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -230,7 +230,7 @@ register(
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -259,7 +259,7 @@ register(
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.TwoPolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -286,7 +286,7 @@ register(
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -315,7 +315,7 @@ register(
|
||||
"camera_id": 0,
|
||||
"episode_length": 1000,
|
||||
"wrappers": [suite.cartpole.ThreePolesMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -342,7 +342,7 @@ register(
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [manipulation.reach_site.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
@ -365,7 +365,7 @@ register(
|
||||
# "time_limit": 1,
|
||||
"episode_length": 250,
|
||||
"wrappers": [manipulation.reach_site.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 9,
|
||||
"num_basis": 5,
|
||||
"duration": 10,
|
||||
|
@ -69,7 +69,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
"learn_goal": True, # learn the goal position (recommended)
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "motor", # controller type, 'velocity', 'position', and 'motor' (torque control)
|
||||
"policy_type": "motor", # tracking_controller type, 'velocity', 'position', and 'motor' (torque control)
|
||||
"weights_scale": 1, # scaling of MP weights
|
||||
"goal_scale": 1, # scaling of learned goal position
|
||||
"policy_kwargs": { # only required for torque control/PD-Controller
|
||||
@ -83,8 +83,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
# "frame_skip": 1
|
||||
}
|
||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||
# OR for a deterministic ProMP (other mp_kwargs are required, see metaworld_examples):
|
||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||
# OR for a deterministic ProMP (other traj_gen_kwargs are required, see metaworld_examples):
|
||||
# env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=mp_args)
|
||||
|
||||
# This renders the full MP trajectory
|
||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||
|
@ -73,12 +73,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
"width": 0.025, # width of the basis functions
|
||||
"zero_start": True, # start from current environment position if True
|
||||
"weights_scale": 1, # scaling of MP weights
|
||||
"policy_type": "metaworld", # custom controller type for metaworld environments
|
||||
"policy_type": "metaworld", # custom tracking_controller type for metaworld environments
|
||||
}
|
||||
|
||||
env = alr_envs.make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||
# OR for a DMP (other mp_kwargs are required, see dmc_examples):
|
||||
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||
# OR for a DMP (other traj_gen_kwargs are required, see dmc_examples):
|
||||
# env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs, **kwargs)
|
||||
|
||||
# This renders the full MP trajectory
|
||||
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||
|
@ -57,7 +57,7 @@ def example_custom_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1, iterations=
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# Changing the mp_kwargs is possible by providing them to gym.
|
||||
# Changing the traj_gen_kwargs is possible by providing them to gym.
|
||||
# E.g. here by providing way to many basis functions
|
||||
mp_kwargs = {
|
||||
"num_dof": 5,
|
||||
@ -126,7 +126,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
}
|
||||
env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||
# OR for a deterministic ProMP:
|
||||
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||
# env = make_promp_env(base_env, wrappers=wrappers, seed=seed, traj_gen_kwargs=traj_gen_kwargs)
|
||||
|
||||
if render:
|
||||
env.render(mode="human")
|
||||
|
@ -4,7 +4,7 @@ import alr_envs
|
||||
def example_mp(env_name, seed=1):
|
||||
"""
|
||||
Example for running a motion primitive based version of a OpenAI-gym environment, which is already registered.
|
||||
For more information on motion primitive specific stuff, look at the mp examples.
|
||||
For more information on motion primitive specific stuff, look at the trajectory_generator examples.
|
||||
Args:
|
||||
env_name: ProMP env_id
|
||||
seed: seed
|
||||
|
@ -8,7 +8,7 @@ from alr_envs.utils.make_env_helpers import make_promp_env
|
||||
|
||||
def visualize(env):
|
||||
t = env.t
|
||||
pos_features = env.mp.basis_generator.basis(t)
|
||||
pos_features = env.trajectory_generator.basis_generator.basis(t)
|
||||
plt.plot(t, pos_features)
|
||||
plt.show()
|
||||
|
||||
|
@ -19,7 +19,7 @@ for _task in _goal_change_envs:
|
||||
kwargs={
|
||||
"name": _task,
|
||||
"wrappers": [goal_change_mp_wrapper.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 6.25,
|
||||
@ -42,7 +42,7 @@ for _task in _object_change_envs:
|
||||
kwargs={
|
||||
"name": _task,
|
||||
"wrappers": [object_change_mp_wrapper.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 6.25,
|
||||
@ -75,7 +75,7 @@ for _task in _goal_and_object_change_envs:
|
||||
kwargs={
|
||||
"name": _task,
|
||||
"wrappers": [goal_object_change_mp_wrapper.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 6.25,
|
||||
@ -98,7 +98,7 @@ for _task in _goal_and_endeffector_change_envs:
|
||||
kwargs={
|
||||
"name": _task,
|
||||
"wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 6.25,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union, Tuple
|
||||
from abc import ABC
|
||||
from typing import Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
@ -7,77 +7,77 @@ from gym import spaces
|
||||
from mp_pytorch.mp.mp_interfaces import MPInterface
|
||||
|
||||
from alr_envs.mp.controllers.base_controller import BaseController
|
||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
||||
class BlackBoxWrapper(gym.ObservationWrapper, ABC):
|
||||
|
||||
def __init__(self,
|
||||
env: RawInterfaceWrapper,
|
||||
trajectory_generator: MPInterface, tracking_controller: BaseController,
|
||||
duration: float, verbose: int = 1, sequencing=True, reward_aggregation: callable = np.sum):
|
||||
"""
|
||||
Base class for movement primitive based gym.Wrapper implementations.
|
||||
gym.Wrapper for leveraging a black box approach with a trajectory generator.
|
||||
|
||||
Args:
|
||||
env: The (wrapped) environment this wrapper is applied on
|
||||
num_dof: Dimension of the action space of the wrapped env
|
||||
num_basis: Number of basis functions per dof
|
||||
trajectory_generator: Generates the full or partial trajectory
|
||||
tracking_controller: Translates the desired trajectory to raw action sequences
|
||||
duration: Length of the trajectory of the movement primitive in seconds
|
||||
controller: Type or object defining the policy that is used to generate action based on the trajectory
|
||||
weight_scale: Scaling parameter for the actions given to this wrapper
|
||||
render_mode: Equivalent to gym render mode
|
||||
verbose: level of detail for returned values in info dict.
|
||||
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
|
||||
reward, default summation over all values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
mp: MPInterface,
|
||||
controller: BaseController,
|
||||
duration: float,
|
||||
render_mode: str = None,
|
||||
verbose: int = 1,
|
||||
weight_scale: float = 1,
|
||||
sequencing=True,
|
||||
reward_aggregation=np.mean,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.env = env
|
||||
try:
|
||||
self.dt = env.dt
|
||||
except AttributeError:
|
||||
raise AttributeError("step based environment needs to have a function 'dt' ")
|
||||
self.duration = duration
|
||||
self.traj_steps = int(duration / self.dt)
|
||||
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
|
||||
# duration = self.env.max_episode_steps * self.dt
|
||||
|
||||
self.mp = mp
|
||||
self.env = env
|
||||
self.controller = controller
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
# rendering
|
||||
self.render_mode = render_mode
|
||||
self.render_kwargs = {}
|
||||
# trajectory generation
|
||||
self.trajectory_generator = trajectory_generator
|
||||
self.tracking_controller = tracking_controller
|
||||
# self.weight_scale = weight_scale
|
||||
self.time_steps = np.linspace(0, self.duration, self.traj_steps)
|
||||
self.mp.set_mp_times(self.time_steps)
|
||||
# self.mp.set_mp_duration(self.time_steps, dt)
|
||||
# action_bounds = np.inf * np.ones((np.prod(self.mp.num_params)))
|
||||
self.mp_action_space = self.get_mp_action_space()
|
||||
self.trajectory_generator.set_mp_times(self.time_steps)
|
||||
# self.trajectory_generator.set_mp_duration(self.time_steps, dt)
|
||||
# action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
|
||||
self.reward_aggregation = reward_aggregation
|
||||
|
||||
# spaces
|
||||
self.mp_action_space = self.get_mp_action_space()
|
||||
self.action_space = self.get_action_space()
|
||||
self.active_obs = self.set_active_obs()
|
||||
self.observation_space = spaces.Box(low=self.env.observation_space.low[self.active_obs],
|
||||
high=self.env.observation_space.high[self.active_obs],
|
||||
self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask],
|
||||
high=self.env.observation_space.high[self.env.context_mask],
|
||||
dtype=self.env.observation_space.dtype)
|
||||
|
||||
# rendering
|
||||
self.render_mode = None
|
||||
self.render_kwargs = {}
|
||||
|
||||
self.verbose = verbose
|
||||
|
||||
@property
|
||||
def dt(self):
|
||||
return self.env.dt
|
||||
|
||||
def observation(self, observation):
|
||||
return observation[self.env.context_mask]
|
||||
|
||||
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
||||
# TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
|
||||
# the beginning of the array.
|
||||
ignore_indices = int(self.mp.learn_tau) + int(self.mp.learn_delay)
|
||||
scaled_mp_params = action.copy()
|
||||
scaled_mp_params[ignore_indices:] *= self.weight_scale
|
||||
self.mp.set_params(np.clip(scaled_mp_params, self.mp_action_space.low, self.mp_action_space.high))
|
||||
self.mp.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos, bc_vel=self.current_vel)
|
||||
traj_dict = self.mp.get_mp_trajs(get_pos=True, get_vel=True)
|
||||
# ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
|
||||
# scaled_mp_params = action.copy()
|
||||
# scaled_mp_params[ignore_indices:] *= self.weight_scale
|
||||
|
||||
clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
|
||||
self.trajectory_generator.set_params(clipped_params)
|
||||
self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
|
||||
bc_vel=self.current_vel)
|
||||
traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
|
||||
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
|
||||
|
||||
trajectory = trajectory_tensor.numpy()
|
||||
@ -86,13 +86,13 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
||||
# TODO: Do we need this or does mp_pytorch have this?
|
||||
if self.post_traj_steps > 0:
|
||||
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dof))])
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.trajectory_generator.num_dof))])
|
||||
|
||||
return trajectory, velocity
|
||||
|
||||
def get_mp_action_space(self):
|
||||
"""This function can be used to set up an individual space for the parameters of the mp."""
|
||||
min_action_bounds, max_action_bounds = self.mp.get_param_bounds()
|
||||
"""This function can be used to set up an individual space for the parameters of the trajectory_generator."""
|
||||
min_action_bounds, max_action_bounds = self.trajectory_generator.get_param_bounds()
|
||||
mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
|
||||
dtype=np.float32)
|
||||
return mp_action_space
|
||||
@ -109,71 +109,6 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
||||
except AttributeError:
|
||||
return self.get_mp_action_space()
|
||||
|
||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
"""
|
||||
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
||||
include other actions like ball releasing time for the beer pong environment.
|
||||
This only needs to be overwritten if the action space is modified.
|
||||
Args:
|
||||
action: a vector instance of the whole action space, includes mp parameters and additional parameters if
|
||||
specified, else only mp parameters
|
||||
|
||||
Returns:
|
||||
Tuple: mp_arguments and other arguments
|
||||
"""
|
||||
return action, None
|
||||
|
||||
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
|
||||
np.ndarray]:
|
||||
"""
|
||||
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
|
||||
Beerpong env. The parameters used should not be part of the motion primitive parameters.
|
||||
Returns step_action by default, can be overwritten in individual mp_wrappers.
|
||||
Args:
|
||||
t: the current time step of the episode
|
||||
env_spec_params: the environment specific parameter, as defined in fucntion _episode_callback
|
||||
(e.g. ball release time in Beer Pong)
|
||||
step_action: the current step-based action
|
||||
|
||||
Returns:
|
||||
modified step action
|
||||
"""
|
||||
return step_action
|
||||
|
||||
@abstractmethod
|
||||
def set_active_obs(self) -> np.ndarray:
|
||||
"""
|
||||
This function defines the contexts. The contexts are defined as specific observations.
|
||||
Returns:
|
||||
boolearn array representing the indices of the observations
|
||||
|
||||
"""
|
||||
return np.ones(self.env.observation_space.shape[0], dtype=bool)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
"""
|
||||
Returns the current position of the action/control dimension.
|
||||
The dimensionality has to match the action/control dimension.
|
||||
This is not required when exclusively using velocity control,
|
||||
it should, however, be implemented regardless.
|
||||
E.g. The joint positions that are directly or indirectly controlled by the action.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
"""
|
||||
Returns the current velocity of the action/control dimension.
|
||||
The dimensionality has to match the action/control dimension.
|
||||
This is not required when exclusively using position control,
|
||||
it should, however, be implemented regardless.
|
||||
E.g. The joint velocities that are directly or indirectly controlled by the action.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||
# TODO: Think about sequencing
|
||||
@ -184,46 +119,52 @@ class EpisodicWrapper(gym.Env, gym.wrappers.TransformReward, ABC):
|
||||
|
||||
# TODO
|
||||
# self.time_steps = np.linspace(0, learned_duration, self.traj_steps)
|
||||
# self.mp.set_mp_times(self.time_steps)
|
||||
# self.trajectory_generator.set_mp_times(self.time_steps)
|
||||
|
||||
trajectory_length = len(trajectory)
|
||||
rewards = np.zeros(shape=(trajectory_length,))
|
||||
if self.verbose >= 2:
|
||||
actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
|
||||
observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
|
||||
dtype=self.env.observation_space.dtype)
|
||||
rewards = np.zeros(shape=(trajectory_length,))
|
||||
trajectory_return = 0
|
||||
|
||||
infos = dict()
|
||||
done = False
|
||||
|
||||
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
||||
step_action = self.controller.get_action(pos_vel[0], pos_vel[1], self.current_pos, self.current_vel)
|
||||
step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos,
|
||||
self.current_vel)
|
||||
step_action = self._step_callback(t, env_spec_params, step_action) # include possible callback info
|
||||
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
||||
# print('step/clipped action ratio: ', step_action/c_action)
|
||||
obs, c_reward, done, info = self.env.step(c_action)
|
||||
rewards[t] = c_reward
|
||||
|
||||
if self.verbose >= 2:
|
||||
actions[t, :] = c_action
|
||||
rewards[t] = c_reward
|
||||
observations[t, :] = obs
|
||||
trajectory_return += c_reward
|
||||
|
||||
for k, v in info.items():
|
||||
elems = infos.get(k, [None] * trajectory_length)
|
||||
elems[t] = v
|
||||
infos[k] = elems
|
||||
# infos['step_infos'].append(info)
|
||||
if self.render_mode:
|
||||
|
||||
if self.render_mode is not None:
|
||||
self.render(mode=self.render_mode, **self.render_kwargs)
|
||||
if done or do_replanning(kwargs):
|
||||
|
||||
if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
|
||||
break
|
||||
|
||||
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
||||
|
||||
if self.verbose >= 2:
|
||||
infos['trajectory'] = trajectory
|
||||
infos['step_actions'] = actions[:t + 1]
|
||||
infos['step_observations'] = observations[:t + 1]
|
||||
infos['step_rewards'] = rewards[:t + 1]
|
||||
|
||||
infos['trajectory_length'] = t + 1
|
||||
done = True
|
||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||
return self.get_observation_from_step(obs), trajectory_return, done, infos
|
||||
|
||||
def reset(self):
|
@ -6,8 +6,8 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
||||
class MetaWorldController(BaseController):
|
||||
"""
|
||||
A Metaworld Controller. Using position and velocity information from a provided environment,
|
||||
the controller calculates a response based on the desired position and velocity.
|
||||
Unlike the other Controllers, this is a special controller for MetaWorld environments.
|
||||
the tracking_controller calculates a response based on the desired position and velocity.
|
||||
Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
|
||||
They use a position delta for the xyz coordinates and a raw position for the gripper opening.
|
||||
|
||||
:param env: A position environment
|
||||
|
@ -6,7 +6,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
||||
class PDController(BaseController):
|
||||
"""
|
||||
A PD-Controller. Using position and velocity information from a provided environment,
|
||||
the controller calculates a response based on the desired position and velocity
|
||||
the tracking_controller calculates a response based on the desired position and velocity
|
||||
|
||||
:param env: A position environment
|
||||
:param p_gains: Factors for the proportional gains
|
||||
|
@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
||||
|
||||
class PosController(BaseController):
|
||||
"""
|
||||
A Position Controller. The controller calculates a response only based on the desired position.
|
||||
A Position Controller. The tracking_controller calculates a response only based on the desired position.
|
||||
"""
|
||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||
return des_pos
|
||||
|
@ -3,7 +3,7 @@ from alr_envs.mp.controllers.base_controller import BaseController
|
||||
|
||||
class VelController(BaseController):
|
||||
"""
|
||||
A Velocity Controller. The controller calculates a response only based on the desired velocity.
|
||||
A Velocity Controller. The tracking_controller calculates a response only based on the desired velocity.
|
||||
"""
|
||||
def get_action(self, des_pos, des_vel, c_pos, c_vel):
|
||||
return des_vel
|
||||
|
@ -7,16 +7,16 @@ from mp_pytorch.basis_gn.basis_generator import BasisGenerator
|
||||
ALL_TYPES = ["promp", "dmp", "idmp"]
|
||||
|
||||
|
||||
def get_movement_primitive(
|
||||
movement_primitives_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
|
||||
def get_trajectory_generator(
|
||||
trajectory_generator_type: str, action_dim: int, basis_generator: BasisGenerator, **kwargs
|
||||
):
|
||||
movement_primitives_type = movement_primitives_type.lower()
|
||||
if movement_primitives_type == "promp":
|
||||
trajectory_generator_type = trajectory_generator_type.lower()
|
||||
if trajectory_generator_type == "promp":
|
||||
return ProMP(basis_generator, action_dim, **kwargs)
|
||||
elif movement_primitives_type == "dmp":
|
||||
elif trajectory_generator_type == "dmp":
|
||||
return DMP(basis_generator, action_dim, **kwargs)
|
||||
elif movement_primitives_type == 'idmp':
|
||||
elif trajectory_generator_type == 'idmp':
|
||||
return IDMP(basis_generator, action_dim, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Specified movement primitive type {movement_primitives_type} not supported, "
|
||||
raise ValueError(f"Specified movement primitive type {trajectory_generator_type} not supported, "
|
||||
f"please choose one of {ALL_TYPES}.")
|
88
alr_envs/mp/raw_interface_wrapper.py
Normal file
88
alr_envs/mp/raw_interface_wrapper.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import Union, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class RawInterfaceWrapper(gym.Wrapper):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def context_mask(self) -> np.ndarray:
|
||||
"""
|
||||
This function defines the contexts. The contexts are defined as specific observations.
|
||||
Returns:
|
||||
bool array representing the indices of the observations
|
||||
|
||||
"""
|
||||
return np.ones(self.env.observation_space.shape[0], dtype=bool)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
"""
|
||||
Returns the current position of the action/control dimension.
|
||||
The dimensionality has to match the action/control dimension.
|
||||
This is not required when exclusively using velocity control,
|
||||
it should, however, be implemented regardless.
|
||||
E.g. The joint positions that are directly or indirectly controlled by the action.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
"""
|
||||
Returns the current velocity of the action/control dimension.
|
||||
The dimensionality has to match the action/control dimension.
|
||||
This is not required when exclusively using position control,
|
||||
it should, however, be implemented regardless.
|
||||
E.g. The joint velocities that are directly or indirectly controlled by the action.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dt(self) -> float:
|
||||
"""
|
||||
Control frequency of the environment
|
||||
Returns: float
|
||||
|
||||
"""
|
||||
|
||||
def do_replanning(self, pos, vel, s, a, t):
|
||||
# return t % 100 == 0
|
||||
# return bool(self.replanning_model(s))
|
||||
return False
|
||||
|
||||
def _episode_callback(self, action: np.ndarray) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
|
||||
"""
|
||||
Used to extract the parameters for the motion primitive and other parameters from an action array which might
|
||||
include other actions like ball releasing time for the beer pong environment.
|
||||
This only needs to be overwritten if the action space is modified.
|
||||
Args:
|
||||
action: a vector instance of the whole action space, includes trajectory_generator parameters and additional parameters if
|
||||
specified, else only trajectory_generator parameters
|
||||
|
||||
Returns:
|
||||
Tuple: mp_arguments and other arguments
|
||||
"""
|
||||
return action, None
|
||||
|
||||
def _step_callback(self, t: int, env_spec_params: Union[np.ndarray, None], step_action: np.ndarray) -> Union[
|
||||
np.ndarray]:
|
||||
"""
|
||||
This function can be used to modify the step_action with additional parameters e.g. releasing the ball in the
|
||||
Beerpong env. The parameters used should not be part of the motion primitive parameters.
|
||||
Returns step_action by default, can be overwritten in individual mp_wrappers.
|
||||
Args:
|
||||
t: the current time step of the episode
|
||||
env_spec_params: the environment specific parameter, as defined in function _episode_callback
|
||||
(e.g. ball release time in Beer Pong)
|
||||
step_action: the current step-based action
|
||||
|
||||
Returns:
|
||||
modified step action
|
||||
"""
|
||||
return step_action
|
@ -21,7 +21,7 @@ register(
|
||||
kwargs={
|
||||
"name": "alr_envs:MountainCarContinuous-v1",
|
||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 4,
|
||||
"duration": 2,
|
||||
@ -43,7 +43,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 1,
|
||||
"num_basis": 4,
|
||||
"duration": 19.98,
|
||||
@ -65,7 +65,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.mujoco:Reacher-v2",
|
||||
"wrappers": [mujoco.reacher_v2.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 6,
|
||||
"duration": 1,
|
||||
@ -87,7 +87,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -105,7 +105,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchSlide-v1",
|
||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -123,7 +123,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchReachDense-v1",
|
||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -141,7 +141,7 @@ register(
|
||||
kwargs={
|
||||
"name": "gym.envs.robotics:FetchReach-v1",
|
||||
"wrappers": [FlattenObservation, robotics.fetch.MPWrapper],
|
||||
"mp_kwargs": {
|
||||
"traj_gen_kwargs": {
|
||||
"num_dof": 4,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
|
@ -4,17 +4,15 @@ from typing import Iterable, Type, Union, Mapping, MutableMapping
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym.envs.registration import EnvSpec
|
||||
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
from mp_env_api.mp_wrappers.promp_wrapper import ProMPWrapper
|
||||
from mp_pytorch import MPInterface
|
||||
|
||||
from alr_envs.mp.basis_generator_factory import get_basis_generator
|
||||
from alr_envs.mp.black_box_wrapper import BlackBoxWrapper
|
||||
from alr_envs.mp.controllers.base_controller import BaseController
|
||||
from alr_envs.mp.controllers.controller_factory import get_controller
|
||||
from alr_envs.mp.mp_factory import get_movement_primitive
|
||||
from alr_envs.mp.episodic_wrapper import EpisodicWrapper
|
||||
from alr_envs.mp.mp_factory import get_trajectory_generator
|
||||
from alr_envs.mp.phase_generator_factory import get_phase_generator
|
||||
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
|
||||
|
||||
|
||||
def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
|
||||
@ -100,9 +98,8 @@ def make(env_id: str, seed, **kwargs):
|
||||
|
||||
|
||||
def _make_wrapped_env(
|
||||
env_id: str, wrappers: Iterable[Type[gym.Wrapper]], mp: MPInterface, controller: BaseController,
|
||||
ep_wrapper_kwargs: Mapping, seed=1, **kwargs
|
||||
):
|
||||
env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs
|
||||
):
|
||||
"""
|
||||
Helper function for creating a wrapped gym environment using MPs.
|
||||
It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is
|
||||
@ -118,73 +115,74 @@ def _make_wrapped_env(
|
||||
"""
|
||||
# _env = gym.make(env_id)
|
||||
_env = make(env_id, seed, **kwargs)
|
||||
has_episodic_wrapper = False
|
||||
has_black_box_wrapper = False
|
||||
for w in wrappers:
|
||||
# only wrap the environment if not EpisodicWrapper, e.g. for vision
|
||||
if not issubclass(w, EpisodicWrapper):
|
||||
# only wrap the environment if not BlackBoxWrapper, e.g. for vision
|
||||
if issubclass(w, RawInterfaceWrapper):
|
||||
has_black_box_wrapper = True
|
||||
_env = w(_env)
|
||||
else: # if EpisodicWrapper, use specific constructor
|
||||
has_episodic_wrapper = True
|
||||
_env = w(env=_env, mp=mp, controller=controller, **ep_wrapper_kwargs)
|
||||
if not has_episodic_wrapper:
|
||||
raise ValueError("An EpisodicWrapper is required in order to leverage movement primitive environments.")
|
||||
if not has_black_box_wrapper:
|
||||
raise ValueError("An RawInterfaceWrapper is required in order to leverage movement primitive environments.")
|
||||
return _env
|
||||
|
||||
|
||||
def make_mp_from_kwargs(
|
||||
env_id: str, wrappers: Iterable, ep_wrapper_kwargs: MutableMapping, mp_kwargs: MutableMapping,
|
||||
def make_bb_env(
|
||||
env_id: str, wrappers: Iterable, black_box_wrapper_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
||||
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed=1,
|
||||
sequenced=False, **kwargs
|
||||
):
|
||||
sequenced=False, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom DMP environment.
|
||||
Args:
|
||||
ep_wrapper_kwargs:
|
||||
basis_kwargs:
|
||||
phase_kwargs:
|
||||
controller_kwargs:
|
||||
black_box_wrapper_kwargs: kwargs for the black-box wrapper
|
||||
basis_kwargs: kwargs for the basis generator
|
||||
phase_kwargs: kwargs for the phase generator
|
||||
controller_kwargs: kwargs for the tracking controller
|
||||
env_id: base_env_name,
|
||||
wrappers: list of wrappers (at least an EpisodicWrapper),
|
||||
wrappers: list of wrappers (at least an BlackBoxWrapper),
|
||||
seed: seed of environment
|
||||
sequenced: When true, this allows to sequence multiple ProMPs by specifying the duration of each sub-trajectory,
|
||||
this behavior is much closer to step based learning.
|
||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
||||
traj_gen_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
||||
|
||||
Returns: DMP wrapped gym env
|
||||
|
||||
"""
|
||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
dummy_env = make(env_id, seed)
|
||||
if ep_wrapper_kwargs.get('duration', None) is None:
|
||||
ep_wrapper_kwargs['duration'] = dummy_env.spec.max_episode_steps * dummy_env.dt
|
||||
_verify_time_limit(traj_gen_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
if black_box_wrapper_kwargs.get('duration', None) is None:
|
||||
black_box_wrapper_kwargs['duration'] = _env.spec.max_episode_steps * _env.dt
|
||||
if phase_kwargs.get('tau', None) is None:
|
||||
phase_kwargs['tau'] = ep_wrapper_kwargs['duration']
|
||||
mp_kwargs['action_dim'] = mp_kwargs.get('action_dim', np.prod(dummy_env.action_space.shape).item())
|
||||
phase_kwargs['tau'] = black_box_wrapper_kwargs['duration']
|
||||
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(_env.action_space.shape).item())
|
||||
|
||||
phase_gen = get_phase_generator(**phase_kwargs)
|
||||
basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
|
||||
controller = get_controller(**controller_kwargs)
|
||||
mp = get_movement_primitive(basis_generator=basis_gen, **mp_kwargs)
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, mp=mp, controller=controller,
|
||||
ep_wrapper_kwargs=ep_wrapper_kwargs, seed=seed, **kwargs)
|
||||
return _env
|
||||
traj_gen = get_trajectory_generator(basis_generator=basis_gen, **traj_gen_kwargs)
|
||||
|
||||
bb_env = BlackBoxWrapper(_env, trajectory_generator=traj_gen, tracking_controller=controller,
|
||||
**black_box_wrapper_kwargs)
|
||||
|
||||
return bb_env
|
||||
|
||||
|
||||
def make_mp_env_helper(**kwargs):
|
||||
def make_bb_env_helper(**kwargs):
|
||||
"""
|
||||
Helper function for registering a DMP gym environments.
|
||||
Helper function for registering a black box gym environment.
|
||||
Args:
|
||||
**kwargs: expects at least the following:
|
||||
{
|
||||
"name": base environment name.
|
||||
"wrappers": list of wrappers (at least an EpisodicWrapper is required),
|
||||
"movement_primitives_kwargs": {
|
||||
"movement_primitives_type": type_of_your_movement_primitive,
|
||||
"wrappers": list of wrappers (at least an BlackBoxWrapper is required),
|
||||
"traj_gen_kwargs": {
|
||||
"trajectory_generator_type": type_of_your_movement_primitive,
|
||||
non default arguments for the movement primitive instance
|
||||
...
|
||||
}
|
||||
"controller_kwargs": {
|
||||
"controller_type": type_of_your_controller,
|
||||
non default arguments for the controller instance
|
||||
non default arguments for the tracking_controller instance
|
||||
...
|
||||
},
|
||||
"basis_generator_kwargs": {
|
||||
@ -205,97 +203,19 @@ def make_mp_env_helper(**kwargs):
|
||||
seed = kwargs.pop("seed", None)
|
||||
wrappers = kwargs.pop("wrappers")
|
||||
|
||||
mp_kwargs = kwargs.pop("movement_primitives_kwargs")
|
||||
ep_wrapper_kwargs = kwargs.pop('ep_wrapper_kwargs')
|
||||
contr_kwargs = kwargs.pop("controller_kwargs")
|
||||
phase_kwargs = kwargs.pop("phase_generator_kwargs")
|
||||
basis_kwargs = kwargs.pop("basis_generator_kwargs")
|
||||
traj_gen_kwargs = kwargs.pop("traj_gen_kwargs", {})
|
||||
black_box_kwargs = kwargs.pop('black_box_wrapper_kwargs', {})
|
||||
contr_kwargs = kwargs.pop("controller_kwargs", {})
|
||||
phase_kwargs = kwargs.pop("phase_generator_kwargs", {})
|
||||
basis_kwargs = kwargs.pop("basis_generator_kwargs", {})
|
||||
|
||||
return make_mp_from_kwargs(env_id=kwargs.pop("name"), wrappers=wrappers, ep_wrapper_kwargs=ep_wrapper_kwargs,
|
||||
mp_kwargs=mp_kwargs, controller_kwargs=contr_kwargs, phase_kwargs=phase_kwargs,
|
||||
return make_bb_env(env_id=kwargs.pop("name"), wrappers=wrappers,
|
||||
black_box_wrapper_kwargs=black_box_kwargs,
|
||||
traj_gen_kwargs=traj_gen_kwargs, controller_kwargs=contr_kwargs,
|
||||
phase_kwargs=phase_kwargs,
|
||||
basis_kwargs=basis_kwargs, **kwargs, seed=seed)
|
||||
|
||||
|
||||
def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom DMP environment.
|
||||
Args:
|
||||
env_id: base_env_name,
|
||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
||||
seed: seed of environment
|
||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP
|
||||
|
||||
Returns: DMP wrapped gym env
|
||||
|
||||
"""
|
||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
_verify_dof(_env, mp_kwargs.get("num_dof"))
|
||||
|
||||
return DmpWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
def make_promp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs):
|
||||
"""
|
||||
This can also be used standalone for manually building a custom ProMP environment.
|
||||
Args:
|
||||
env_id: base_env_name,
|
||||
wrappers: list of wrappers (at least an MPEnvWrapper),
|
||||
mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int}
|
||||
|
||||
Returns: ProMP wrapped gym env
|
||||
|
||||
"""
|
||||
_verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
|
||||
_verify_dof(_env, mp_kwargs.get("num_dof"))
|
||||
|
||||
return ProMPWrapper(_env, **mp_kwargs)
|
||||
|
||||
|
||||
def make_dmp_env_helper(**kwargs):
|
||||
"""
|
||||
Helper function for registering a DMP gym environments.
|
||||
Args:
|
||||
**kwargs: expects at least the following:
|
||||
{
|
||||
"name": base_env_name,
|
||||
"wrappers": list of wrappers (at least an MPEnvWrapper),
|
||||
"mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
|
||||
}
|
||||
|
||||
Returns: DMP wrapped gym env
|
||||
|
||||
"""
|
||||
seed = kwargs.pop("seed", None)
|
||||
return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
||||
|
||||
|
||||
def make_promp_env_helper(**kwargs):
|
||||
"""
|
||||
Helper function for registering ProMP gym environments.
|
||||
This can also be used standalone for manually building a custom ProMP environment.
|
||||
Args:
|
||||
**kwargs: expects at least the following:
|
||||
{
|
||||
"name": base_env_name,
|
||||
"wrappers": list of wrappers (at least an MPEnvWrapper),
|
||||
"mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
|
||||
}
|
||||
|
||||
Returns: ProMP wrapped gym env
|
||||
|
||||
"""
|
||||
seed = kwargs.pop("seed", None)
|
||||
return make_promp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
||||
mp_kwargs=kwargs.pop("mp_kwargs"), **kwargs)
|
||||
|
||||
|
||||
def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
||||
"""
|
||||
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
|
||||
@ -304,7 +224,7 @@ def _verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[
|
||||
It can be found in the BaseMP class.
|
||||
|
||||
Args:
|
||||
mp_time_limit: max trajectory length of mp in seconds
|
||||
mp_time_limit: max trajectory length of trajectory_generator in seconds
|
||||
env_time_limit: max trajectory length of DMC environment in seconds
|
||||
|
||||
Returns:
|
||||
|
Loading…
Reference in New Issue
Block a user