random sampling for goal switching & adjust height for initial ball state
This commit is contained in:
parent
f47f00a292
commit
96f17e02cf
@ -260,7 +260,8 @@ for ctxt_dim in [2, 4]:
|
|||||||
"ctxt_dim": ctxt_dim,
|
"ctxt_dim": ctxt_dim,
|
||||||
'frame_skip': 4,
|
'frame_skip': 4,
|
||||||
'enable_wind': False,
|
'enable_wind': False,
|
||||||
'enable_switching_goal': False,
|
'enable_switching_goal': True,
|
||||||
|
'enable_air': False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -87,8 +87,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
unstable_simulation = False
|
unstable_simulation = False
|
||||||
|
|
||||||
if self._enable_goal_switching:
|
if self._enable_goal_switching:
|
||||||
if self._steps == 45 and self.np_random.uniform(0, 1) < 0.5:
|
if self._steps == 99 and self.np_random.uniform(0, 1) < 0.5:
|
||||||
self._goal_pos[1] = -self._goal_pos[1]
|
new_goal_pos = self._generate_goal_pos(random=True)
|
||||||
|
new_goal_pos[1] = -new_goal_pos[1]
|
||||||
|
self._goal_pos = new_goal_pos
|
||||||
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
|
||||||
mujoco.mj_forward(self.model, self.data)
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
@ -151,8 +153,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self._init_ball_state = self._generate_valid_init_ball(random_pos=False, random_vel=False)
|
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
|
||||||
self._goal_pos = self._generate_goal_pos(random=False)
|
self._init_ball_state[2] = 1.85
|
||||||
|
self._goal_pos = self._generate_goal_pos(random=True)
|
||||||
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
||||||
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
||||||
self.data.joint("tar_z").qpos = self._init_ball_state[2]
|
self.data.joint("tar_z").qpos = self._init_ball_state[2]
|
||||||
@ -167,7 +170,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
mujoco.mj_forward(self.model, self.data)
|
mujoco.mj_forward(self.model, self.data)
|
||||||
|
|
||||||
if self._enable_wind:
|
if self._enable_wind:
|
||||||
self._wind_vel[1] = self.np_random.uniform(low=-5, high=5, size=1)
|
self._wind_vel[1] = self.np_random.uniform(low=-10, high=10, size=1)
|
||||||
self.model.opt.wind[:3] = self._wind_vel
|
self.model.opt.wind[:3] = self._wind_vel
|
||||||
|
|
||||||
self._hit_ball = False
|
self._hit_ball = False
|
||||||
@ -251,37 +254,43 @@ def plot_ball_traj_2d(x_traj, y_traj):
|
|||||||
ax.plot(x_traj, y_traj)
|
ax.plot(x_traj, y_traj)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def plot_single_axis(traj, title):
|
def plot_compare_trajs(traj1, traj2, title):
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
ax = fig.add_subplot(111)
|
ax = fig.add_subplot(111)
|
||||||
ax.plot(traj)
|
ax.plot(traj1, color='r', label='traj1')
|
||||||
|
ax.plot(traj2, color='b', label='traj2')
|
||||||
ax.set_title(title)
|
ax.set_title(title)
|
||||||
|
plt.legend()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = TableTennisEnv(enable_air=True)
|
env_air = TableTennisEnv(enable_air=True, enable_wind=False)
|
||||||
# env_with_air = TableTennisEnv(enable_air=True)
|
env_no_air = TableTennisEnv(enable_air=False, enable_wind=False)
|
||||||
for _ in range(1):
|
for _ in range(10):
|
||||||
obs1 = env.reset()
|
obs1 = env_air.reset()
|
||||||
|
obs2 = env_no_air.reset()
|
||||||
# obs2 = env_with_air.reset()
|
# obs2 = env_with_air.reset()
|
||||||
x_pos = []
|
air_x_pos = []
|
||||||
y_pos = []
|
no_air_x_pos = []
|
||||||
z_pos = []
|
# y_pos = []
|
||||||
x_vel = []
|
# z_pos = []
|
||||||
y_vel = []
|
# x_vel = []
|
||||||
z_vel = []
|
# y_vel = []
|
||||||
|
# z_vel = []
|
||||||
for _ in range(2000):
|
for _ in range(2000):
|
||||||
obs, reward, done, info = env.step(np.zeros(7))
|
# env_air.render("human")
|
||||||
|
obs1, reward1, done1, info1 = env_air.step(np.zeros(7))
|
||||||
|
obs2, reward2, done2, info2 = env_no_air.step(np.zeros(7))
|
||||||
# _, _, _, _ = env_no_air.step(np.zeros(7))
|
# _, _, _, _ = env_no_air.step(np.zeros(7))
|
||||||
x_pos.append(env.data.joint("tar_x").qpos[0])
|
air_x_pos.append(env_air.data.joint("tar_z").qpos[0])
|
||||||
y_pos.append(env.data.joint("tar_y").qpos[0])
|
no_air_x_pos.append(env_no_air.data.joint("tar_z").qpos[0])
|
||||||
z_pos.append(env.data.joint("tar_z").qpos[0])
|
# z_pos.append(env.data.joint("tar_z").qpos[0])
|
||||||
x_vel.append(env.data.joint("tar_x").qvel[0])
|
# x_vel.append(env.data.joint("tar_x").qvel[0])
|
||||||
y_vel.append(env.data.joint("tar_y").qvel[0])
|
# y_vel.append(env.data.joint("tar_y").qvel[0])
|
||||||
z_vel.append(env.data.joint("tar_z").qvel[0])
|
# z_vel.append(env.data.joint("tar_z").qvel[0])
|
||||||
# print(reward)
|
# print(reward)
|
||||||
if done:
|
if info1["num_steps"] == 150:
|
||||||
# plot_ball_traj_2d(x_pos, y_pos)
|
# plot_ball_traj_2d(x_pos, y_pos)
|
||||||
plot_single_axis(x_pos, title="x_vel without air")
|
plot_compare_trajs(air_x_pos, no_air_x_pos, title="z_pos with/out air")
|
||||||
break
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user