170 lines
4.8 KiB
Python
170 lines
4.8 KiB
Python
"""
|
|
Plotting D3IL trajectories
|
|
|
|
"""
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import os
|
|
from functools import partial
|
|
|
|
|
|
class TrajPlotter:
|
|
|
|
def __init__(self, env_type, **kwargs):
|
|
if env_type == "toy":
|
|
self.save_traj = save_toy_traj
|
|
elif env_type == "avoid":
|
|
self.save_traj = partial(save_avoid_traj, **kwargs)
|
|
else:
|
|
self.save_traj = dummy
|
|
|
|
def __call__(self, **kwargs):
|
|
self.save_traj(**kwargs)
|
|
|
|
|
|
def dummy(*args, **kwargs):
|
|
pass
|
|
|
|
|
|
def save_avoid_traj(
|
|
obs_full_trajs,
|
|
n_render,
|
|
max_episode_steps,
|
|
render_dir,
|
|
itr,
|
|
normalization_path,
|
|
):
|
|
normalization = np.load(normalization_path)
|
|
obs_min = normalization["obs_min"]
|
|
obs_max = normalization["obs_max"]
|
|
|
|
# action_min = normalization['action_min']
|
|
# action_max = normalization['action_max']
|
|
def unnormalize_obs(obs):
|
|
obs = (obs + 1) / 2 # [-1, 1] -> [0, 1]
|
|
return obs * (obs_max - obs_min) + obs_min
|
|
|
|
def get_obj_xy_list():
|
|
mid_pos = 0.5
|
|
offset = 0.075
|
|
first_level_y = -0.1
|
|
level_distance = 0.18
|
|
return [
|
|
[mid_pos, first_level_y],
|
|
[mid_pos - offset, first_level_y + level_distance],
|
|
[mid_pos + offset, first_level_y + level_distance],
|
|
[mid_pos - 2 * offset, first_level_y + 2 * level_distance],
|
|
[mid_pos, first_level_y + 2 * level_distance],
|
|
[mid_pos + 2 * offset, first_level_y + 2 * level_distance],
|
|
]
|
|
|
|
pillar_xys = get_obj_xy_list()
|
|
chosen_i = np.random.choice(
|
|
range(obs_full_trajs.shape[1]),
|
|
n_render,
|
|
replace=False,
|
|
)
|
|
fig = plt.figure()
|
|
for i in chosen_i:
|
|
obs_traj_env = obs_full_trajs[:max_episode_steps, i, :]
|
|
obs_traj_env = unnormalize_obs(obs_traj_env)
|
|
|
|
# bnds = np.array([[0, 8], [-3, 3]]) # denormalize
|
|
# obs_traj_env = obs_traj_env * (bnds[:, 1] - bnds[:, 0]) + bnds[:, 0]
|
|
# for j in range(len(obs_traj_env) - 4, len(obs_traj_env)):
|
|
for j in range(len(obs_traj_env)):
|
|
plt.scatter(
|
|
obs_traj_env[j, 0],
|
|
obs_traj_env[j, 1],
|
|
marker="o",
|
|
s=2,
|
|
# s=0.2,
|
|
# c=plt.cm.Blues(1 - j / 50 + 0.1),
|
|
color=(0.3, 0.3, 0.3),
|
|
)
|
|
if j > 0: # connect
|
|
plt.plot(
|
|
[obs_traj_env[j - 1, 0], obs_traj_env[j, 0]],
|
|
[obs_traj_env[j - 1, 1], obs_traj_env[j, 1]],
|
|
color=(0.3, 0.3, 0.3),
|
|
)
|
|
# finish line
|
|
plt.axhline(y=0.4, color=np.array([31, 119, 180]) / 255, linestyle="-")
|
|
for xy in pillar_xys:
|
|
circle = plt.Circle(xy, 0.01, color=(0.0, 0.0, 0.0), fill=True)
|
|
plt.gca().add_patch(circle)
|
|
|
|
plt.xlabel("X pos")
|
|
plt.ylabel("Y pos")
|
|
|
|
plt.xlim([0.2, 0.8])
|
|
plt.ylim([-0.3, 0.5])
|
|
ax = plt.gca()
|
|
ax.set_aspect("equal", adjustable="box")
|
|
ax.set_facecolor("white")
|
|
plt.savefig(os.path.join(render_dir, f"traj-{itr}.png"))
|
|
plt.close(fig)
|
|
|
|
|
|
def save_toy_traj(
|
|
obs_full_trajs,
|
|
n_render,
|
|
max_episode_steps,
|
|
render_dir,
|
|
itr,
|
|
):
|
|
chosen_i = np.random.choice(
|
|
range(obs_full_trajs.shape[1]),
|
|
n_render,
|
|
replace=False,
|
|
)
|
|
for i in chosen_i:
|
|
obs_traj_env = obs_full_trajs[:max_episode_steps, i, :]
|
|
bnds = np.array([[0, 8], [-3, 3]]) # denormalize
|
|
obs_traj_env = obs_traj_env * (bnds[:, 1] - bnds[:, 0]) + bnds[:, 0]
|
|
fig = plt.figure()
|
|
for j in range(max_episode_steps):
|
|
plt.scatter(
|
|
obs_traj_env[j, 0],
|
|
obs_traj_env[j, 1],
|
|
marker="o",
|
|
s=20,
|
|
c=plt.cm.Blues(1 - j / 50 + 0.1),
|
|
)
|
|
if j > 0: # connect
|
|
plt.plot(
|
|
[obs_traj_env[j - 1, 0], obs_traj_env[j, 0]],
|
|
[obs_traj_env[j - 1, 1], obs_traj_env[j, 1]],
|
|
"k-",
|
|
)
|
|
plt.scatter(
|
|
obs_traj_env[0, 0],
|
|
obs_traj_env[0, 1],
|
|
marker="*",
|
|
s=100,
|
|
c="g",
|
|
)
|
|
plt.scatter(6, 0, marker="*", s=100, c="r") # target
|
|
circle = plt.Circle((3, 0), 1, color="r", fill=True)
|
|
plt.gca().add_patch(circle)
|
|
plt.plot(
|
|
[
|
|
bnds[0, 0],
|
|
bnds[0, 1],
|
|
bnds[0, 1],
|
|
bnds[0, 0],
|
|
bnds[0, 0],
|
|
],
|
|
[
|
|
bnds[1, 0],
|
|
bnds[1, 0],
|
|
bnds[1, 1],
|
|
bnds[1, 1],
|
|
bnds[1, 0],
|
|
],
|
|
"k-",
|
|
)
|
|
plt.savefig(os.path.join(render_dir, f"traj-{itr}-{i}.png"))
|
|
plt.close(fig)
|