goal switching

Hongyi Zhou 2022-11-16 19:45:58 +01:00
parent f9c0c1f3ab
commit d4e844ac45
5 changed files with 27 additions and 23 deletions

View File

@@ -163,15 +163,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
-        # time_valid = self.env.check_time_validity(action)
+        # time_is_valid = self.env.check_time_validity(action)
         #
         # if time_valid:
-        ## tricky part, only use weights basis
-        # basis_weights = action.reshape(7, -1)
-        # goal_weights = np.zeros((7, 1))
-        # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
         # TODO remove this part, right now only needed for beer pong
         # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
         position, velocity = self.get_trajectory(action)
@@ -253,9 +248,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             else:
                 obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
                 return self.observation(obs), trajectory_return, done, infos
-        # else:
-        #     obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
-        #     return self.observation(obs), trajectory_return, done, infos

     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
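Note: with the commented-out time-validity branch removed, step() funnels every generated trajectory through a single validity check before execution. A minimal sketch of the resulting control flow, assuming the helper names visible in the hunks above (check_traj_validity and execute_traj are illustrative guesses, not confirmed by this diff):

# Sketch of BlackBoxWrapper.step() after this commit; names marked "assumed" are illustrative.
def step(self, action: np.ndarray):
    position, velocity = self.get_trajectory(action)  # roll MP parameters out to a trajectory
    if self.env.check_traj_validity(action, position, velocity):  # assumed helper
        obs, trajectory_return, done, infos = self.execute_traj(position, velocity)  # assumed helper
    else:
        # invalid trajectories are penalized via the callback instead of being executed
        obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
    return self.observation(obs), trajectory_return, done, infos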

View File

@@ -260,7 +260,7 @@ for ctxt_dim in [2, 4]:
             "ctxt_dim": ctxt_dim,
             'frame_skip': 4,
             'enable_wind': False,
-            'enable_switching_goal': False,
+            'enable_switching_goal': True,
         }
     )
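The dict above is the kwargs payload of a classic Gym registration, forwarded verbatim to TableTennisEnv.__init__ at construction time. A hedged sketch of the surrounding call (the env id pattern and entry point are assumptions, not shown in this commit):

from gym.envs.registration import register

for ctxt_dim in [2, 4]:
    register(
        id=f"TableTennis{ctxt_dim}D-v0",  # hypothetical id pattern
        entry_point="envs.mujoco.table_tennis.table_tennis_env:TableTennisEnv",  # assumed module path
        max_episode_steps=250,  # MAX_EPISODE_STEPS_TABLE_TENNIS in this commit
        kwargs={
            "ctxt_dim": ctxt_dim,
            "frame_skip": 4,
            "enable_wind": False,
            "enable_switching_goal": True,  # the flag this hunk flips on
        },
    )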

View File

@ -16,7 +16,7 @@ class MPWrapper(RawInterfaceWrapper):
[False] * 7, # joints velocity [False] * 7, # joints velocity
[True] * 2, # position ball x, y [True] * 2, # position ball x, y
[False] * 1, # position ball z [False] * 1, # position ball z
[True] * 3, # velocity ball x, y, z [False] * 3, # velocity ball x, y, z
[True] * 2, # target landing position [True] * 2, # target landing position
# [True] * 1, # time # [True] * 1, # time
]) ])
@@ -33,7 +33,7 @@ class MPWrapper(RawInterfaceWrapper):
         return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
             and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]

-    def time_invalid_traj_callback(self, action) \
+    def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
             -> Tuple[np.ndarray, float, bool, dict]:
         tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
         delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
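Both penalties are hinge-style costs: zero while tau (action[0]) and delay (action[1]) stay inside their bounds, growing linearly with the violation otherwise. A worked example with assumed bound values (the actual tau_bound/delay_bound constants live elsewhere in the module):

import numpy as np

tau_bound = [0.8, 1.5]      # assumed values, for illustration only
delay_bound = [0.05, 0.15]  # assumed values, for illustration only

action = np.array([2.0, 0.0])  # tau overshoots its upper bound by 0.5; delay undershoots by 0.05
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
print(tau_invalid_penalty, delay_invalid_penalty)  # ~1.5 and ~0.15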

View File

@@ -13,6 +13,8 @@ MAX_EPISODE_STEPS_TABLE_TENNIS = 250
 CONTEXT_BOUNDS_2DIMS = np.array([[-1.0, -0.65], [-0.2, 0.65]])
 CONTEXT_BOUNDS_4DIMS = np.array([[-1.0, -0.65, -1.0, -0.65],
                                  [-0.2, 0.65, -0.2, 0.65]])
+CONTEXT_BOUNDS_SWICHING = np.array([[-1.0, -0.65, -1.0, 0.1],
+                                    [-0.2, 0.65, -0.2, 0.65]])


 class TableTennisEnv(MujocoEnv, utils.EzPickle):
@@ -20,9 +22,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     7 DoF table tennis environment
     """

-    def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4,
+    def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
                  enable_switching_goal: bool = False,
-                 enable_wind: bool = False, enable_magnus: bool = False):
+                 enable_wind: bool = False, enable_magnus: bool = False,
+                 enable_air: bool = False):
         utils.EzPickle.__init__(**locals())
         self._steps = 0
@@ -53,14 +56,18 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             self.context_bounds = CONTEXT_BOUNDS_2DIMS
         elif ctxt_dim == 4:
             self.context_bounds = CONTEXT_BOUNDS_4DIMS
+            if self._enable_goal_switching:
+                self.context_bounds = CONTEXT_BOUNDS_SWICHING
         else:
             raise NotImplementedError

         self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)

         # complex dynamics settings
-        # self.model.opt.density = 1.225
-        # self.model.opt.viscosity = 2.27e-5
+        if enable_air:
+            self.model.opt.density = 1.225
+            self.model.opt.viscosity = 2.27e-5
         self._enable_wind = enable_wind
         self._enable_magnus = enable_magnus
         self._wind_vel = np.zeros(3)
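The bounds arrays keep lower limits in row 0 and upper limits in row 1; in the goal-switching variant the fourth context dimension is narrowed to [0.1, 0.65]. The diff does not show how contexts are drawn, but a uniform sample within the bounds is the obvious reading:

import numpy as np

CONTEXT_BOUNDS_SWICHING = np.array([[-1.0, -0.65, -1.0, 0.1],
                                    [-0.2, 0.65, -0.2, 0.65]])

# Sketch: draw a 4D context uniformly between lower (row 0) and upper (row 1) limits;
# how the env actually samples is not visible in this commit.
rng = np.random.default_rng(0)
ctxt = rng.uniform(CONTEXT_BOUNDS_SWICHING[0], CONTEXT_BOUNDS_SWICHING[1])
print(ctxt.shape)  # (4,)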
@@ -244,17 +251,20 @@ def plot_ball_traj_2d(x_traj, y_traj):
     ax.plot(x_traj, y_traj)
     plt.show()

-def plot_single_axis(traj):
+def plot_single_axis(traj, title):
     import matplotlib.pyplot as plt
     fig = plt.figure()
     ax = fig.add_subplot(111)
     ax.plot(traj)
+    ax.set_title(title)
     plt.show()

 if __name__ == "__main__":
-    env = TableTennisEnv(enable_wind=True)
-    for _ in range(5):
-        obs = env.reset()
+    env = TableTennisEnv(enable_air=True)
+    # env_with_air = TableTennisEnv(enable_air=True)
+    for _ in range(1):
+        obs1 = env.reset()
+        # obs2 = env_with_air.reset()
         x_pos = []
         y_pos = []
         z_pos = []
@@ -262,8 +272,8 @@ if __name__ == "__main__":
         y_vel = []
         z_vel = []
         for _ in range(2000):
-            # env.render("human")
             obs, reward, done, info = env.step(np.zeros(7))
+            # _, _, _, _ = env_no_air.step(np.zeros(7))
             x_pos.append(env.data.joint("tar_x").qpos[0])
             y_pos.append(env.data.joint("tar_y").qpos[0])
             z_pos.append(env.data.joint("tar_z").qpos[0])
@@ -272,6 +282,6 @@ if __name__ == "__main__":
             z_vel.append(env.data.joint("tar_z").qvel[0])
             # print(reward)
             if done:
-                plot_ball_traj_2d(x_pos, y_pos)
-                plot_single_axis(x_vel)
+                # plot_ball_traj_2d(x_pos, y_pos)
+                plot_single_axis(x_pos, title="x_vel without air")
                 break
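The new enable_air flag activates the previously commented-out medium parameters: model.opt.density = 1.225 matches air at sea level (kg/m^3), so MuJoCo simulates the ball flying through air rather than a vacuum, i.e. it adds aerodynamic drag. A sketch comparing both settings with the helpers defined in this file (assuming TableTennisEnv and plot_single_axis are importable as written above):

import numpy as np

# Compare the ball's x-trajectory with and without aerodynamic drag.
for label, enable_air in [("vacuum", False), ("air", True)]:
    env = TableTennisEnv(enable_air=enable_air)
    env.reset()
    x_pos = []
    for _ in range(2000):
        obs, reward, done, info = env.step(np.zeros(7))
        x_pos.append(env.data.joint("tar_x").qpos[0])
        if done:
            break
    plot_single_axis(x_pos, title=f"tar_x ({label})")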

View File

@@ -45,6 +45,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
         # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
         # to the return of a trajectory. Default is the sum over the step-wise rewards.
+        print(f'target obs: {obs[-3:]}')
         obs, reward, done, info = env.step(ac)
         print(f'steps: {info["num_steps"][-1]}')
         # Aggregated returns
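The added print assumes the goal-related context sits at the tail of the observation vector (obs[-3:]); with goal switching enabled, that is where a mid-episode target change would become visible. A hypothetical invocation against the table tennis task (the env id is a guess, not part of this commit):

# Hypothetical usage; the id follows the repo's naming convention loosely and may differ.
example_mp(env_name="TableTennis4DProMP-v0", seed=1, iterations=3, render=False)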