add invalid trajectory callback & invalid traj return & register all 3 variants of table tennis tasks

Hongyi Zhou 2022-12-01 11:28:03 +01:00
parent 28aa430fd2
commit f376772c22
9 changed files with 151 additions and 162 deletions

View File

@@ -63,14 +63,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.traj_gen_action_space = self._get_traj_gen_action_space()
self.action_space = self._get_action_space()
# no goal learning
# tricky_action_upperbound = [np.inf] * (self.traj_gen_action_space.shape[0] - 7)
# tricky_action_lowerbound = [-np.inf] * (self.traj_gen_action_space.shape[0] - 7)
# self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
self.action_space.low[0] = 0.8
self.action_space.high[0] = 1.5
self.action_space.low[1] = 0.05
self.action_space.high[1] = 0.15
self.observation_space = self._get_observation_space()
# rendering
@@ -93,8 +85,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
return observation.astype(self.observation_space.dtype)
def get_trajectory(self, action: np.ndarray) -> Tuple:
# duration = self.duration
duration = self.duration - self.current_traj_steps * self.dt
duration = self.duration
# duration = self.duration - self.current_traj_steps * self.dt
if self.learn_sub_trajectories:
duration = None
# reset with every new call as we need to set all arguments, such as tau, delay, again.
@@ -157,8 +149,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
# TODO remove this part, right now only needed for beer pong
# mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
position, velocity = self.get_trajectory(action)
traj_is_valid = self.env.episode_callback(action, position, velocity)
traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity)
# insert validation here
trajectory_length = len(position)
rewards = np.zeros(shape=(trajectory_length,))
if self.verbose >= 2:
@@ -169,7 +161,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos = dict()
done = False
if traj_is_valid:
if not traj_is_valid:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
self.return_context_observation)
return self.observation(obs), trajectory_return, done, infos
else:
self.plan_steps += 1
for t, (pos, vel) in enumerate(zip(position, velocity)):
current_pos = self.current_pos
@@ -215,10 +211,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
infos['trajectory_length'] = t + 1
trajectory_return = self.reward_aggregation(rewards[:t + 1])
return self.observation(obs), trajectory_return, done, infos
else:
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
self.return_context_observation)
return self.observation(obs), trajectory_return, done, infos
def render(self, **kwargs):
"""Only set render options here, such that they can be used during the rollout.

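The hunks above invert the validity check in `BlackBoxWrapper.step` so the invalid branch exits early: the trajectory is generated, handed to the raw env's `preprocessing_and_validity_callback`, and only rolled out if it passes. A minimal sketch of the resulting control flow, assuming the callback's result is treated as a plain boolean flag at the call site (rollout details omitted):

```python
# Sketch of BlackBoxWrapper.step() after this change (abridged).
def step(self, action):
    position, velocity = self.get_trajectory(action)
    # The raw env may preprocess the trajectory and judge its validity.
    traj_is_valid = self.env.preprocessing_and_validity_callback(action, position, velocity)
    if not traj_is_valid:
        # Invalid trajectory: the raw env fabricates a terminal transition
        # (penalized return, done=True) instead of rolling anything out.
        obs, trajectory_return, done, infos = self.env.invalid_traj_callback(
            action, position, velocity, self.return_context_observation)
        return self.observation(obs), trajectory_return, done, infos
    # Valid trajectory: roll it out step by step as before.
    ...
```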
View File

@@ -52,6 +52,19 @@ class RawInterfaceWrapper(gym.Wrapper):
"""
return self.env.dt
def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
-> Tuple[bool, np.ndarray, np.ndarray]:
"""
Used to preprocess the action and check if the desired trajectory is valid.
"""
return True, pos_traj, vel_traj
def set_episode_arguments(self, action, pos_traj, vel_traj):
"""
Used to set the arguments for the env that are valid for the whole episode.
"""
return pos_traj, vel_traj
def episode_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.array) -> Tuple[bool]:
"""
Used to extract the parameters for the movement primitive and other parameters from an action array which might
@@ -68,7 +81,6 @@ class RawInterfaceWrapper(gym.Wrapper):
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
"""
Used to return a fake return from the environment if the desired trajectory is invalid.
Used to return an artificial return from the env if the desired trajectory is invalid.
"""
obs = np.zeros(1)
return obs, 0, True, {}
return np.zeros(1), 0, True, {}

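Both new hooks default to permissive no-ops, so envs that never produce invalid trajectories need no changes. A hypothetical override illustrating the contract; the class name and the `jnt_low`/`jnt_high` limits are placeholders, and the extra `return_contextual_obs` parameter matches how `BlackBoxWrapper` invokes the hook:

```python
import numpy as np
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper

class JointLimitedWrapper(RawInterfaceWrapper):
    jnt_low, jnt_high = -2.9, 2.9  # placeholder joint limits

    def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
        # Reject the trajectory if any desired waypoint leaves the joint limits.
        valid = not (np.any(pos_traj > self.jnt_high) or np.any(pos_traj < self.jnt_low))
        return valid, pos_traj, vel_traj

    def invalid_traj_callback(self, action, pos_traj, vel_traj, return_contextual_obs):
        # Terminal transition with a penalty proportional to the violation.
        penalty = np.mean(np.maximum(pos_traj - self.jnt_high, 0)) \
            + np.mean(np.maximum(self.jnt_low - pos_traj, 0))
        return np.zeros(1), -penalty, True, {'trajectory_length': 1}
```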
View File

@@ -18,6 +18,8 @@ from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING
from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching, \
MAX_EPISODE_STEPS_TABLE_TENNIS
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
@@ -248,17 +250,28 @@ for ctxt_dim in [2, 4]:
register(
id='TableTennis{}D-v0'.format(ctxt_dim),
entry_point='fancy_gym.envs.mujoco:TableTennisEnv',
max_episode_steps=350,
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
kwargs={
"ctxt_dim": ctxt_dim,
'frame_skip': 4,
'enable_wind': False,
'enable_switching_goal': False,
'enable_air': False,
'enable_artifical_wind': False,
'goal_switching_step': None,
'enable_artificial_wind': False,
}
)
register(
id='TableTennisWind-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisWind',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
)
register(
id='TableTennisGoalSwitching-v0',
entry_point='fancy_gym.envs.mujoco:TableTennisGoalSwitching',
max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS,
)
# movement Primitive Environments
## Simple Reacher
@@ -529,17 +542,22 @@ for _v in _versions:
ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProDMP"].append(_env_id)
## Table Tennis
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0']
_versions = ['TableTennis2D-v0', 'TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']
for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_tt_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.MPWrapper)
if _v == 'TableTennisWind-v0':
kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
else:
kwargs_dict_tt_promp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
kwargs_dict_tt_promp['name'] = _v
kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
@@ -556,7 +574,10 @@ for _v in _versions:
_name = _v.split("-")
_env_id = f'{_name[0]}ProDMP-{_name[1]}'
kwargs_dict_tt_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.MPWrapper)
if _v == 'TableTennisWind-v0':
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TTVelObs_MPWrapper)
else:
kwargs_dict_tt_prodmp['wrappers'].append(mujoco.table_tennis.TT_MPWrapper)
kwargs_dict_tt_prodmp['name'] = _v
kwargs_dict_tt_prodmp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
kwargs_dict_tt_prodmp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
@@ -564,12 +585,13 @@ for _v in _versions:
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.0
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = False
kwargs_dict_tt_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
kwargs_dict_tt_prodmp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
kwargs_dict_tt_prodmp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_tau'] = True
kwargs_dict_tt_prodmp['phase_generator_kwargs']['learn_delay'] = True
kwargs_dict_tt_prodmp['basis_generator_kwargs']['num_basis'] = 2
kwargs_dict_tt_prodmp['basis_generator_kwargs']['alpha'] = 25.
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
#kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0

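With the registrations above in place, all three step-based variants (and their MP counterparts, e.g. `TableTennisWindProMP-v0`) are reachable through the standard gym registry. A usage sketch under the old gym 4-tuple step API this codebase targets:

```python
import gym
import numpy as np
import fancy_gym  # noqa: F401 -- importing fancy_gym runs the register() calls above

for env_id in ['TableTennis4D-v0', 'TableTennisWind-v0', 'TableTennisGoalSwitching-v0']:
    env = gym.make(env_id)
    obs = env.reset()
    # All variants share the 7-DoF torque action space.
    obs, reward, done, info = env.step(np.zeros(env.action_space.shape))
    print(env_id, obs.shape, reward, done)
```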
View File

@@ -8,4 +8,4 @@ from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
from .reacher.reacher import ReacherEnv
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
from .table_tennis.table_tennis_env import TableTennisEnv
from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching

View File

@@ -1 +1 @@
from .mp_wrapper import MPWrapper
from .mp_wrapper import TT_MPWrapper, TTVelObs_MPWrapper

View File

@@ -6,7 +6,7 @@ from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
class MPWrapper(RawInterfaceWrapper):
class TT_MPWrapper(RawInterfaceWrapper):
# Random x goal + random init pos
@property
@@ -29,48 +29,26 @@ class MPWrapper(RawInterfaceWrapper):
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
return self.data.qvel[:7].copy()
def check_time_validity(self, action):
return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]
def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
return self.check_traj_validity(action, pos_traj, vel_traj)
def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
-> Tuple[np.ndarray, float, bool, dict]:
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty
obs = np.concatenate([self.get_obs(), np.array([0])])
return obs, -invalid_penalty, True, {
"hit_ball": [False],
"ball_returned_success": [False],
"land_dist_error": [10.],
"is_success": [False],
'trajectory_length': 1,
"num_steps": [1]
}
def set_episode_arguments(self, action, pos_traj, vel_traj):
return pos_traj, vel_traj
def episode_callback(self, action, pos_traj, vel_traj):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return False
return True
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray,
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
violate_high_bound_error + violate_low_bound_error
obs = np.concatenate([self.get_obs(), np.array([0])])
if return_contextual_obs:
obs = self.get_obs()
return obs, -invalid_penalty, True, {
"hit_ball": [False],
"ball_returned_success": [False],
"land_dist_error": [10.],
"is_success": [False],
'trajectory_length': 1,
"num_steps": [1]
}
return self.get_invalid_traj_step_return(action, pos_traj, vel_traj, return_contextual_obs)
class TTVelObs_MPWrapper(TT_MPWrapper):
@property
def context_mask(self):
return np.hstack([
[False] * 7, # joints position
[False] * 7, # joints velocity
[True] * 2, # position ball x, y
[False] * 1, # position ball z
[True] * 3, # velocity ball x, y, z
[True] * 2, # target landing position
# [True] * 1, # time
])

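The wrapper now delegates both the validity check and the invalid-trajectory return to the env (`check_traj_validity` / `get_invalid_traj_step_return`), removing the duplicated logic deleted above. For reference, the penalty that logic computes, restated as a standalone function using the same bound imports the wrapper already has:

```python
import numpy as np
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import (
    jnt_pos_low, jnt_pos_high, delay_bound, tau_bound)

def invalid_traj_penalty(action, pos_traj):
    # How far tau (action[0]) and delay (action[1]) lie outside their bounds, weighted by 3.
    tau_pen = 3 * (max(0., action[0] - tau_bound[1]) + max(0., tau_bound[0] - action[0]))
    delay_pen = 3 * (max(0., action[1] - delay_bound[1]) + max(0., delay_bound[0] - action[1]))
    # Mean joint-limit violation across the whole desired trajectory.
    high_pen = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
    low_pen = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
    return tau_pen + delay_pen + high_pen + low_pen  # the env returns the negative of this
```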
View File

@@ -5,6 +5,7 @@ from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import check_init_state_validity, magnus_force
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
import mujoco
@@ -21,13 +22,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
"""
7 DoF table tennis environment
"""
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
enable_switching_goal: bool = False,
enable_wind: bool = False,
enable_artifical_wind: bool = False,
enable_magnus: bool = False,
enable_air: bool = False):
goal_switching_step: int = None,
enable_artificial_wind: bool = False):
utils.EzPickle.__init__(**locals())
self._steps = 0
@@ -47,12 +44,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._ball_traj = []
self._racket_traj = []
self._goal_switching_step = goal_switching_step
self._enable_goal_switching = enable_switching_goal
self._enable_artificial_wind = enable_artificial_wind
self._enable_artifical_wind = enable_artifical_wind
self._artifical_force = 0.
self._artificial_force = 0.
MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
@@ -62,20 +58,13 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.context_bounds = CONTEXT_BOUNDS_2DIMS
elif ctxt_dim == 4:
self.context_bounds = CONTEXT_BOUNDS_4DIMS
if self._enable_goal_switching:
if self._goal_switching_step is not None:
self.context_bounds = CONTEXT_BOUNDS_SWICHING
else:
raise NotImplementedError
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
# complex dynamics settings
if enable_air:
self.model.opt.density = 1.225
self.model.opt.viscosity = 2.27e-5
self._enable_wind = enable_wind
self._enable_magnus = enable_magnus
self._wind_vel = np.zeros(3)
def _set_ids(self):
@@ -92,9 +81,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False
if self._enable_goal_switching:
if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
if self._steps == self._goal_switching_step and self.np_random.uniform(0, 1) < 0.5:
new_goal_pos = self._generate_goal_pos(random=True)
new_goal_pos[1] = -new_goal_pos[1]
self._goal_pos = new_goal_pos
@@ -102,8 +89,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
mujoco.mj_forward(self.model, self.data)
for _ in range(self.frame_skip):
if self._enable_artifical_wind:
self.data.qfrc_applied[-2] = self._artifical_force
if self._enable_artificial_wind:
self.data.qfrc_applied[-2] = self._artificial_force
try:
self.do_simulation(action, 1)
except Exception as e:
@@ -163,7 +150,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def reset_model(self):
self._steps = 0
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
# self._init_ball_state[2] = 1.85
self._goal_pos = self._generate_goal_pos(random=True)
self.data.joint("tar_x").qpos = self._init_ball_state[0]
self.data.joint("tar_y").qpos = self._init_ball_state[1]
@@ -172,19 +158,16 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.data.joint("tar_y").qvel = self._init_ball_state[4]
self.data.joint("tar_z").qvel = self._init_ball_state[5]
if self._enable_artifical_wind:
self._artifical_force = self.np_random.uniform(low=-0.1, high=0.1)
if self._enable_artificial_wind:
self._artificial_force = self.np_random.uniform(low=-0.1, high=0.1)
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
self.data.qpos[:7] = np.array([0., 0., 0., 1.5, 0., 0., 1.5])
self.data.qvel[:7] = np.zeros(7)
mujoco.mj_forward(self.model, self.data)
if self._enable_wind:
self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1)
self.model.opt.wind[:3] = self._wind_vel
self._hit_ball = False
self._ball_land_on_table = False
self._ball_contact_after_hit = False
@@ -208,10 +191,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.data.joint("tar_x").qpos.copy(),
self.data.joint("tar_y").qpos.copy(),
self.data.joint("tar_z").qpos.copy(),
#self.data.joint("tar_x").qvel.copy(),
#self.data.joint("tar_y").qvel.copy(),
#self.data.joint("tar_z").qvel.copy(),
# self.data.body("target_ball").xvel.copy(),
# self.data.joint("tar_x").qvel.copy(),
# self.data.joint("tar_y").qvel.copy(),
# self.data.joint("tar_z").qvel.copy(),
self._goal_pos.copy(),
])
return obs
@@ -252,56 +234,54 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
return init_ball_state
def plot_ball_traj(x_traj, y_traj, z_traj):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(x_traj, y_traj, z_traj)
plt.show()
def _get_traj_invalid_reward(self, action, pos_traj, vel_traj):
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
violate_high_bound_error + violate_low_bound_error
return -invalid_penalty
def plot_ball_traj_2d(x_traj, y_traj):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(x_traj, y_traj)
plt.show()
def get_invalid_traj_step_return(self, action, pos_traj, vel_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_reward(action, pos_traj, vel_traj)
return obs, penalty, True, {
"hit_ball": [False],
"ball_return_success": [False],
"land_dist_error": [False],
"trajectory_length": 1,
"num_steps": [1],
}
def plot_compare_trajs(traj1, traj2, title):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(traj1, color='r', label='traj1')
ax.plot(traj2, color='b', label='traj2')
ax.set_title(title)
plt.legend()
plt.show()
@staticmethod
def check_traj_validity(action, pos_traj, vel_traj):
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
return False, pos_traj, vel_traj
return True, pos_traj, vel_traj
if __name__ == "__main__":
env_air = TableTennisEnv(enable_air=False, enable_wind=False, enable_artifical_wind=True)
env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
for _ in range(10):
obs1 = env_air.reset()
obs2 = env_no_air.reset()
# obs2 = env_with_air.reset()
air_x_pos = []
no_air_x_pos = []
# y_pos = []
# z_pos = []
# x_vel = []
# y_vel = []
# z_vel = []
for _ in range(2000):
env_air.render("human")
obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
# # _, _, _, _ = env_no_air.step(np.zeros(7))
air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
# # z_pos.append(env.data.joint("tar_z").qpos[0])
# # x_vel.append(env.data.joint("tar_x").qvel[0])
# # y_vel.append(env.data.joint("tar_y").qvel[0])
# # z_vel.append(env.data.joint("tar_z").qvel[0])
# # print(reward)
if info1["num_steps"] == 150:
plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out wind")
break
class TableTennisWind(TableTennisEnv):
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
def _get_obs(self):
obs = np.concatenate([
self.data.qpos.flat[:7].copy(),
self.data.qvel.flat[:7].copy(),
self.data.joint("tar_x").qpos.copy(),
self.data.joint("tar_y").qpos.copy(),
self.data.joint("tar_z").qpos.copy(),
self.data.joint("tar_x").qvel.copy(),
self.data.joint("tar_y").qvel.copy(),
self.data.joint("tar_z").qvel.copy(),
self._goal_pos.copy(),
])
return obs
class TableTennisGoalSwitching(TableTennisEnv):
def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99):
super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step)

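The two new subclasses are thin configuration shims over `TableTennisEnv`: `TableTennisWind` pins `enable_artificial_wind=True` and additionally exposes the ball velocity in its observation (which is why it pairs with `TTVelObs_MPWrapper` in the registrations), while `TableTennisGoalSwitching` pins `goal_switching_step=99`. The equivalent constructor calls on the base class, the wind variant's extended observation aside:

```python
# What the subclasses fix, spelled out on the base class:
env_wind = TableTennisEnv(ctxt_dim=4, frame_skip=4, enable_artificial_wind=True)
env_switch = TableTennisEnv(frame_skip=4, goal_switching_step=99)
```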
View File

@@ -155,7 +155,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
render = True
render = False
# DMP
# example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
@@ -163,10 +163,15 @@ if __name__ == '__main__':
# example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
# example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)
# example_mp("TableTennis4DProMP-v0", seed=10, iterations=20, render=render)
# example_mp("TableTennisWindProMP-v0", seed=10, iterations=20, render=render)
# example_mp("TableTennisGoalSwitchingProMP-v0", seed=10, iterations=20, render=render)
# ProDMP
# example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)
example_mp("TableTennis4DProDMP-v0", seed=10, iterations=20, render=render)
example_mp("TableTennis4DProDMP-v0", seed=10, iterations=2000, render=render)
# example_mp("TableTennisWindProDMP-v0", seed=10, iterations=20, render=render)
# example_mp("TableTennisGoalSwitchingProDMP-v0", seed=10, iterations=20, render=render)
# Altered basis functions
# obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)

View File

@@ -168,11 +168,11 @@ def make_bb(
# set tau bounds to minimum of two env steps otherwise computing the velocity is not possible.
# maximum is full duration of one episode.
if phase_kwargs.get('learn_tau'):
if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
phase_kwargs["tau_bound"] = [env.dt * 2, black_box_kwargs['duration']]
# Max delay is full duration minus two steps due to above reason
if phase_kwargs.get('learn_delay'):
if phase_kwargs.get('learn_delay') and phase_kwargs.get('delay_bound') is None:
phase_kwargs["delay_bound"] = [0, black_box_kwargs['duration'] - env.dt * 2]
phase_gen = get_phase_generator(**phase_kwargs)
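
Guarding the defaults with `is None` means `make_bb` only fills in tau/delay bounds when a registration did not set them, so the explicit `tau_bound = [0.8, 1.5]` and `delay_bound = [0.05, 0.15]` from the table-tennis entries above survive. A small sketch of the behavior (the dt and duration values are illustrative):

```python
env_dt, duration = 0.02, 2.0  # illustrative values

phase_kwargs = {'learn_tau': True, 'tau_bound': [0.8, 1.5], 'learn_delay': True}
if phase_kwargs.get('learn_tau') and phase_kwargs.get('tau_bound') is None:
    phase_kwargs['tau_bound'] = [env_dt * 2, duration]          # skipped: already set
if phase_kwargs.get('learn_delay') and phase_kwargs.get('delay_bound') is None:
    phase_kwargs['delay_bound'] = [0, duration - env_dt * 2]    # applied: was unset

assert phase_kwargs['tau_bound'] == [0.8, 1.5]
assert phase_kwargs['delay_bound'] == [0, duration - env_dt * 2]
```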