goal switching

Hongyi Zhou 2022-11-16 19:45:58 +01:00
parent f9c0c1f3ab
commit d4e844ac45
5 changed files with 27 additions and 23 deletions

View File

@@ -163,15 +163,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
    def step(self, action: np.ndarray):
        """This function generates a trajectory based on an MP and then does the usual loop over reset and step."""
-        # time_valid = self.env.check_time_validity(action)
+        # time_is_valid = self.env.check_time_validity(action)
        #
-        # if time_valid:
-        ## tricky part, only use weights basis
-        # basis_weights = action.reshape(7, -1)
-        # goal_weights = np.zeros((7, 1))
-        # action = np.concatenate((basis_weights, goal_weights), axis=1).flatten()
        # TODO remove this part, right now only needed for beer pong
        # mp_params, env_spec_params, proceed = self.env.episode_callback(action, self.traj_gen)
        position, velocity = self.get_trajectory(action)
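For context, a single step of this wrapper replays the whole MP-generated trajectory in the underlying step-based env and sums the rewards into one trajectory return. A minimal sketch of that loop; the controller object and its get_action signature are assumptions, not the wrapper's exact implementation:

def episodic_step(env, position, velocity, controller):
    # Replay the MP-generated (position, velocity) trajectory step by step in
    # the underlying env and aggregate the rewards into one trajectory return.
    trajectory_return, obs, done, info = 0.0, None, False, {}
    for pos, vel in zip(position, velocity):
        action = controller.get_action(pos, vel)  # e.g. a PD tracking controller
        obs, reward, done, info = env.step(action)
        trajectory_return += reward
        if done:
            break
    return obs, trajectory_return, done, info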
@@ -253,9 +248,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
        else:
            obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity)
            return self.observation(obs), trajectory_return, done, infos
-        # else:
-        #     obs, trajectory_return, done, infos = self.env.time_invalid_traj_callback(action)
-        #     return self.observation(obs), trajectory_return, done, infos

    def render(self, **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once."""
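In usage terms this means render is called once, before stepping, and the frames are produced during the internal rollout. A tiny sketch; the mode kwarg is an assumption:

# env: an episodic (BlackBoxWrapper-wrapped) environment; kwarg name assumed.
env.render(mode="human")                 # store render options once
obs, ret, done, info = env.step(action)  # rendering happens inside the rollout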

View File

@@ -260,7 +260,7 @@ for ctxt_dim in [2, 4]:
            "ctxt_dim": ctxt_dim,
            'frame_skip': 4,
            'enable_wind': False,
-            'enable_switching_goal': False,
+            'enable_switching_goal': True,
        }
    )
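The same configuration can be built directly, without the registry; a sketch assuming the module path, where enable_switching_goal presumably makes the target landing position change during the episode:

# Module path assumed from the file layout of this repo.
from fancy_gym.envs.mujoco.table_tennis.table_tennis_env import TableTennisEnv

env = TableTennisEnv(ctxt_dim=4, frame_skip=4,
                     enable_wind=False, enable_switching_goal=True)
obs = env.reset()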

View File

@@ -16,7 +16,7 @@ class MPWrapper(RawInterfaceWrapper):
            [False] * 7,  # joints velocity
            [True] * 2,  # position ball x, y
            [False] * 1,  # position ball z
-            [True] * 3,  # velocity ball x, y, z
+            [False] * 3,  # velocity ball x, y, z
            [True] * 2,  # target landing position
            # [True] * 1,  # time
        ])
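This mask is how the wrapper reduces the full step-based observation to the episodic context; after this change the ball velocity no longer appears in it. A self-contained sketch of the selection, with the feature layout taken from the comments above:

import numpy as np

# Feature layout copied from the comments above; the observation itself is a
# stand-in, only the boolean selection matters here.
context_mask = np.hstack([
    [False] * 7,  # joints velocity
    [True] * 2,   # position ball x, y
    [False] * 1,  # position ball z
    [False] * 3,  # velocity ball x, y, z (masked out by this change)
    [True] * 2,   # target landing position
])
full_obs = np.arange(context_mask.size, dtype=float)  # placeholder observation
context = full_obs[context_mask]                      # -> 4 context features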
@@ -33,7 +33,7 @@ class MPWrapper(RawInterfaceWrapper):
        return action[0] <= tau_bound[1] and action[0] >= tau_bound[0] \
            and action[1] <= delay_bound[1] and action[1] >= delay_bound[0]

-    def time_invalid_traj_callback(self, action) \
+    def time_invalid_traj_callback(self, action, pos_traj, vel_traj) \
            -> Tuple[np.ndarray, float, bool, dict]:
        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
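How the bound check and the penalty fit together: action[0] is the trajectory duration tau and action[1] the onset delay, and the penalty grows linearly with how far either value leaves its bounds. A sketch with assumed bound values:

import numpy as np

tau_bound = [0.5, 1.5]      # assumed values, for illustration only
delay_bound = [0.05, 0.15]  # assumed values, for illustration only

def check_time_validity(action: np.ndarray) -> bool:
    return tau_bound[0] <= action[0] <= tau_bound[1] \
        and delay_bound[0] <= action[1] <= delay_bound[1]

def invalid_time_penalty(action: np.ndarray) -> float:
    # Mirrors the callback above: linear penalty on the bound violation.
    tau_pen = 3 * (max(0, action[0] - tau_bound[1]) + max(0, tau_bound[0] - action[0]))
    delay_pen = 3 * (max(0, action[1] - delay_bound[1]) + max(0, delay_bound[0] - action[1]))
    return tau_pen + delay_pen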

View File

@@ -13,6 +13,8 @@ MAX_EPISODE_STEPS_TABLE_TENNIS = 250
CONTEXT_BOUNDS_2DIMS = np.array([[-1.0, -0.65], [-0.2, 0.65]])
CONTEXT_BOUNDS_4DIMS = np.array([[-1.0, -0.65, -1.0, -0.65],
                                 [-0.2, 0.65, -0.2, 0.65]])
+CONTEXT_BOUNDS_SWITCHING = np.array([[-1.0, -0.65, -1.0, 0.1],
+                                     [-0.2, 0.65, -0.2, 0.65]])

class TableTennisEnv(MujocoEnv, utils.EzPickle):
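In each bounds array, row 0 holds the per-dimension lower bounds and row 1 the upper bounds; a context (target landing position) is presumably drawn uniformly inside that box, roughly like this:

import numpy as np

CONTEXT_BOUNDS_SWITCHING = np.array([[-1.0, -0.65, -1.0, 0.1],
                                     [-0.2, 0.65, -0.2, 0.65]])

def sample_context(bounds: np.ndarray, rng=None) -> np.ndarray:
    # Uniform draw inside the [lower, upper] box spanned by the two rows.
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(bounds[0], bounds[1])

goal = sample_context(CONTEXT_BOUNDS_SWITCHING)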
@@ -20,9 +22,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
    7 DoF table tennis environment
    """
-    def __init__(self, ctxt_dim: int = 2, frame_skip: int = 4,
+    def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
                 enable_switching_goal: bool = False,
-                 enable_wind: bool = False, enable_magnus: bool = False):
+                 enable_wind: bool = False, enable_magnus: bool = False,
+                 enable_air: bool = False):
        utils.EzPickle.__init__(**locals())
        self._steps = 0
@@ -53,14 +56,18 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
            self.context_bounds = CONTEXT_BOUNDS_2DIMS
        elif ctxt_dim == 4:
            self.context_bounds = CONTEXT_BOUNDS_4DIMS
+            if self._enable_goal_switching:
+                self.context_bounds = CONTEXT_BOUNDS_SWITCHING
        else:
            raise NotImplementedError
        self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
        # complex dynamics settings
-        # self.model.opt.density = 1.225
-        # self.model.opt.viscosity = 2.27e-5
+        if enable_air:
+            self.model.opt.density = 1.225
+            self.model.opt.viscosity = 2.27e-5
        self._enable_wind = enable_wind
        self._enable_magnus = enable_magnus
        self._wind_vel = np.zeros(3)
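MuJoCo's opt.density and opt.viscosity describe the surrounding medium and both default to 0, i.e. a vacuum; the values set here approximate air at sea level (density ~1.225 kg/m^3), which is what makes the ball trajectory bend. A quick check of the new flag; the import path is assumed:

from fancy_gym.envs.mujoco.table_tennis.table_tennis_env import TableTennisEnv  # path assumed

env_air = TableTennisEnv(enable_air=True)
env_vac = TableTennisEnv(enable_air=False)
print(env_air.model.opt.density, env_air.model.opt.viscosity)  # ~1.225, ~2.27e-5
print(env_vac.model.opt.density, env_vac.model.opt.viscosity)  # 0.0, 0.0 (vacuum)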
@@ -244,17 +251,20 @@ def plot_ball_traj_2d(x_traj, y_traj):
    ax.plot(x_traj, y_traj)
    plt.show()

-def plot_single_axis(traj):
+def plot_single_axis(traj, title):
    import matplotlib.pyplot as plt
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(traj)
+    ax.set_title(title)
    plt.show()

if __name__ == "__main__":
-    env = TableTennisEnv(enable_wind=True)
-    for _ in range(5):
-        obs = env.reset()
+    env = TableTennisEnv(enable_air=True)
+    # env_no_air = TableTennisEnv(enable_air=False)
+    for _ in range(1):
+        obs1 = env.reset()
+        # obs2 = env_no_air.reset()
        x_pos = []
        y_pos = []
        z_pos = []
@@ -262,8 +272,8 @@ if __name__ == "__main__":
        y_vel = []
        z_vel = []
        for _ in range(2000):
-            # env.render("human")
            obs, reward, done, info = env.step(np.zeros(7))
+            # _, _, _, _ = env_no_air.step(np.zeros(7))
            x_pos.append(env.data.joint("tar_x").qpos[0])
            y_pos.append(env.data.joint("tar_y").qpos[0])
            z_pos.append(env.data.joint("tar_z").qpos[0])
@@ -272,6 +282,6 @@ if __name__ == "__main__":
            z_vel.append(env.data.joint("tar_z").qvel[0])
            # print(reward)
            if done:
-                plot_ball_traj_2d(x_pos, y_pos)
-                plot_single_axis(x_vel)
+                # plot_ball_traj_2d(x_pos, y_pos)
+                plot_single_axis(x_pos, title="x_pos with air")
                break

View File

@@ -45,6 +45,7 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
        # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
        # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
        # to the return of a trajectory. Default is the sum over the step-wise rewards.
+        print(f'target obs: {obs[-3:]}')
        obs, reward, done, info = env.step(ac)
        print(f'steps: {info["num_steps"][-1]}')

        # Aggregated returns
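For reference, the enclosing loop of example_mp looks roughly like the sketch below; the fancy_gym.make call is assumed from the example file's conventions, and one env.step executes an entire trajectory:

import fancy_gym

env = fancy_gym.make("HoleReacherProMP-v0", seed=1)
obs = env.reset()
for _ in range(5):
    ac = env.action_space.sample()          # one MP parameter vector per rollout
    obs, reward, done, info = env.step(ac)  # runs the full trajectory internally
    print(f"trajectory return: {reward}, steps: {info['num_steps'][-1]}")
    if done:
        obs = env.reset()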