dppo/env/plot_traj.py
2024-09-03 21:03:27 -04:00

170 lines
4.8 KiB
Python

"""
Plotting D3IL trajectories
"""
import matplotlib.pyplot as plt
import numpy as np
import os
from functools import partial
class TrajPlotter:
    """
    Callable that forwards to a trajectory-saving function chosen by
    environment type.

    ``"toy"`` selects ``save_toy_traj``; ``"avoid"`` selects
    ``save_avoid_traj`` with the constructor's keyword arguments pre-bound;
    any other value selects the no-op ``dummy``.
    """

    def __init__(self, env_type, **kwargs):
        """
        Pick the saving function for ``env_type``.

        ``kwargs`` are only used for the ``"avoid"`` environment, where they
        are bound to ``save_avoid_traj``; they are ignored otherwise.
        """
        if env_type == "toy":
            saver = save_toy_traj
        elif env_type == "avoid":
            saver = partial(save_avoid_traj, **kwargs)
        else:
            saver = dummy
        self.save_traj = saver

    def __call__(self, **kwargs):
        """Invoke the selected saving function with ``kwargs``; returns None."""
        save = self.save_traj
        save(**kwargs)
def dummy(*args, **kwargs):
    """Accept any arguments and do nothing (used when no plotter applies)."""
    return None
def save_avoid_traj(
    obs_full_trajs,
    n_render,
    max_episode_steps,
    render_dir,
    itr,
    normalization_path,
):
    """
    Render randomly sampled "avoid" trajectories into a single PNG.

    Up to ``n_render`` trajectories are picked without replacement, their
    observations are mapped from ``[-1, 1]`` back to the range given by the
    ``obs_min`` / ``obs_max`` entries stored at ``normalization_path``, and the
    first two observation components of each are overlaid on one figure
    together with the pillar obstacles and the finish line. The figure is
    written to ``<render_dir>/traj-<itr>.png``.

    Args:
        obs_full_trajs: array indexed as ``[step, trajectory, obs_dim]``.
        n_render: number of trajectories to draw; capped at the number of
            trajectories available along axis 1.
        max_episode_steps: only the first ``max_episode_steps`` steps of each
            trajectory are drawn.
        render_dir: directory the image is written to (not created here).
        itr: identifier embedded in the output file name.
        normalization_path: file readable by ``np.load`` that provides
            ``obs_min`` and ``obs_max`` (assumed to be an ``.npz`` archive —
            confirm against the caller).
    """
    # Indexing an .npz archive reads the arrays eagerly, so the underlying
    # file handle can be released as soon as both bounds are loaded.
    with np.load(normalization_path) as normalization:
        obs_min = normalization["obs_min"]
        obs_max = normalization["obs_max"]

    def unnormalize_obs(obs):
        obs = (obs + 1) / 2  # [-1, 1] -> [0, 1]
        return obs * (obs_max - obs_min) + obs_min

    # Pillar centres: three rows of 1, 2 and 3 pillars around x = mid_pos.
    mid_pos = 0.5
    offset = 0.075
    first_level_y = -0.1
    level_distance = 0.18
    pillar_xys = [
        [mid_pos, first_level_y],
        [mid_pos - offset, first_level_y + level_distance],
        [mid_pos + offset, first_level_y + level_distance],
        [mid_pos - 2 * offset, first_level_y + 2 * level_distance],
        [mid_pos, first_level_y + 2 * level_distance],
        [mid_pos + 2 * offset, first_level_y + 2 * level_distance],
    ]

    # Sampling without replacement raises if more items are requested than
    # exist, so never ask for more trajectories than are available.
    n_trajs = obs_full_trajs.shape[1]
    chosen_i = np.random.choice(
        n_trajs,
        min(n_render, n_trajs),
        replace=False,
    )

    traj_color = (0.3, 0.3, 0.3)
    fig, ax = plt.subplots()
    try:
        for i in chosen_i:
            obs_traj_env = unnormalize_obs(
                obs_full_trajs[:max_episode_steps, i, :]
            )
            xs = obs_traj_env[:, 0]
            ys = obs_traj_env[:, 1]
            # One scatter and one polyline per trajectory instead of one
            # artist per step.
            ax.scatter(xs, ys, marker="o", s=2, color=traj_color)
            ax.plot(xs, ys, color=traj_color)

        # finish line
        ax.axhline(y=0.4, color=np.array([31, 119, 180]) / 255, linestyle="-")

        for xy in pillar_xys:
            ax.add_patch(plt.Circle(xy, 0.01, color=(0.0, 0.0, 0.0), fill=True))

        ax.set_xlabel("X pos")
        ax.set_ylabel("Y pos")
        ax.set_xlim([0.2, 0.8])
        ax.set_ylim([-0.3, 0.5])
        ax.set_aspect("equal", adjustable="box")
        ax.set_facecolor("white")
        fig.savefig(os.path.join(render_dir, f"traj-{itr}.png"))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
def save_toy_traj(
    obs_full_trajs,
    n_render,
    max_episode_steps,
    render_dir,
    itr,
):
    """
    Render randomly sampled "toy" trajectories, one PNG per trajectory.

    Up to ``n_render`` trajectories are picked without replacement. Each one
    is rescaled to the arena bounds, drawn as colour-graded points joined by
    a black line, and saved together with the start marker, the target, the
    circular obstacle and the arena outline to
    ``<render_dir>/traj-<itr>-<trajectory index>.png``.

    Args:
        obs_full_trajs: array indexed as ``[step, trajectory, obs_dim]`` whose
            last axis has two components (x, y); presumably normalized to
            ``[0, 1]`` — confirm against the caller.
        n_render: number of trajectories to draw; capped at the number of
            trajectories available along axis 1.
        max_episode_steps: only the first ``max_episode_steps`` steps of each
            trajectory are drawn; shorter trajectories are drawn in full.
        render_dir: directory the images are written to (not created here).
        itr: identifier embedded in the output file names.
    """
    # Row 0 is the x range and row 1 the y range of the arena.
    bnds = np.array([[0, 8], [-3, 3]])

    # Sampling without replacement raises if more items are requested than
    # exist, so never ask for more trajectories than are available.
    n_trajs = obs_full_trajs.shape[1]
    chosen_i = np.random.choice(
        n_trajs,
        min(n_render, n_trajs),
        replace=False,
    )

    for i in chosen_i:
        obs_traj_env = obs_full_trajs[:max_episode_steps, i, :]
        obs_traj_env = obs_traj_env * (bnds[:, 1] - bnds[:, 0]) + bnds[:, 0]  # denormalize
        xs = obs_traj_env[:, 0]
        ys = obs_traj_env[:, 1]
        # Use the actual trajectory length: the slice above may hold fewer
        # than max_episode_steps rows.
        steps = np.arange(len(obs_traj_env))

        fig, ax = plt.subplots()
        try:
            # Colour fades with the step index; an (N, 4) RGBA array gives
            # one colour per point.
            ax.scatter(
                xs,
                ys,
                marker="o",
                s=20,
                c=plt.cm.Blues(1 - steps / 50 + 0.1),
            )
            ax.plot(xs, ys, "k-")  # connect consecutive steps

            # start marker (slicing keeps this safe for an empty trajectory)
            ax.scatter(xs[:1], ys[:1], marker="*", s=100, c="g")
            ax.scatter(6, 0, marker="*", s=100, c="r")  # target
            ax.add_patch(plt.Circle((3, 0), 1, color="r", fill=True))

            # arena outline
            ax.plot(
                [bnds[0, 0], bnds[0, 1], bnds[0, 1], bnds[0, 0], bnds[0, 0]],
                [bnds[1, 0], bnds[1, 0], bnds[1, 1], bnds[1, 1], bnds[1, 0]],
                "k-",
            )
            fig.savefig(os.path.join(render_dir, f"traj-{itr}-{i}.png"))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)