4d table tennis
This commit is contained in:
parent
5a547d85f9
commit
6193f87fe7
@ -14,7 +14,8 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
return np.hstack([
|
return np.hstack([
|
||||||
[False] * 7, # joints position
|
[False] * 7, # joints position
|
||||||
[False] * 7, # joints velocity
|
[False] * 7, # joints velocity
|
||||||
[False] * 3, # position ball
|
[True] * 2, # position ball x, y
|
||||||
|
[False] * 1, # position ball z
|
||||||
[True] * 2, # target landing position
|
[True] * 2, # target landing position
|
||||||
# [True] * 1, # time
|
# [True] * 1, # time
|
||||||
])
|
])
|
||||||
@ -36,10 +37,10 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
||||||
-> Tuple[np.ndarray, float, bool, dict]:
|
-> Tuple[np.ndarray, float, bool, dict]:
|
||||||
tau_invalid_penalty = np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]])
|
tau_invalid_penalty = 0.3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||||
delay_invalid_penalty = np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]])
|
delay_invalid_penalty = 0.3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||||
violate_high_bound_error = np.sum(np.maximum(pos_traj - jnt_pos_high, 0))
|
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||||
violate_low_bound_error = np.sum(np.maximum(jnt_pos_low - pos_traj, 0))
|
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
||||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||||
violate_high_bound_error + violate_low_bound_error
|
violate_high_bound_error + violate_low_bound_error
|
||||||
return self.get_obs(), -invalid_penalty, True, {
|
return self.get_obs(), -invalid_penalty, True, {
|
||||||
|
@ -127,7 +127,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
|
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
|
||||||
self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
|
self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
|
||||||
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
||||||
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
||||||
@ -188,8 +188,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
x_pos, y_pos, z_pos = -0.5, 0.35, 1.75
|
x_pos, y_pos, z_pos = -0.5, 0.35, 1.75
|
||||||
x_vel, y_vel, z_vel = 2.5, 0., 0.5
|
x_vel, y_vel, z_vel = 2.5, 0., 0.5
|
||||||
if random_pos:
|
if random_pos:
|
||||||
x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0], size=1)
|
x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0])
|
||||||
y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1], size=1)
|
y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1])
|
||||||
if random_vel:
|
if random_vel:
|
||||||
x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
|
x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
|
||||||
init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
|
init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
|
||||||
|
@ -2,7 +2,7 @@ import numpy as np
|
|||||||
|
|
||||||
jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
|
jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
|
||||||
jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
|
jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
|
||||||
delay_bound = [0.05, 0.3]
|
delay_bound = [0.05, 0.2]
|
||||||
tau_bound = [0.5, 1.5]
|
tau_bound = [0.5, 1.5]
|
||||||
|
|
||||||
net_height = 0.1
|
net_height = 0.1
|
||||||
|
@ -17,9 +17,12 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
|
|||||||
# It takes care of seeding and enables the use of a variety of external environments using the gym interface.
|
# It takes care of seeding and enables the use of a variety of external environments using the gym interface.
|
||||||
env = fancy_gym.make(env_name, seed)
|
env = fancy_gym.make(env_name, seed)
|
||||||
|
|
||||||
|
# env.traj_gen.basis_gn.show_basis(plot=True)
|
||||||
|
|
||||||
returns = 0
|
returns = 0
|
||||||
# env.render(mode=None)
|
# env.render(mode=None)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
print(obs)
|
||||||
|
|
||||||
# number of samples/full trajectories (multiple environment steps)
|
# number of samples/full trajectories (multiple environment steps)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
@ -46,8 +49,9 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
|
|||||||
returns += reward
|
returns += reward
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
print(reward)
|
# print(reward)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
print(obs)
|
||||||
|
|
||||||
|
|
||||||
def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render=True):
|
def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render=True):
|
||||||
|
Loading…
Reference in New Issue
Block a user