"""
|
|
Process d3il dataset and save it into our custom format so it can be loaded for diffusion training.
|
|
"""
|
|
|
|
import logging
import os
import pickle
import random
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from agent.dataset.d3il_dataset.aligning_dataset import Aligning_Dataset
from agent.dataset.d3il_dataset.avoiding_dataset import Avoiding_Dataset
from agent.dataset.d3il_dataset.pushing_dataset import Pushing_Dataset
from agent.dataset.d3il_dataset.sorting_dataset import Sorting_Dataset
from agent.dataset.d3il_dataset.stacking_dataset import Stacking_Dataset
|
|
|
|
|
|
def make_dataset(load_path, save_dir, save_name_prefix, env_type, val_split):
    """Convert a D3IL demonstration dataset into our custom pickled format.

    Loads the raw demos for ``env_type``, computes per-dimension min/max
    bounds, normalizes observations and actions to [-1, 1], splits the
    trajectories into train/val sets, pickles both splits, saves the
    normalization bounds as an ``.npz``, plots all trajectories, and logs
    summary statistics.

    Args:
        load_path (str): Directory containing the raw d3il demonstrations.
        save_dir (str): Output directory for processed files and plots.
        save_name_prefix (str): Prefix prepended to every output filename.
        env_type (str): One of "align", "avoid", "push", "sort", "stack".
        val_split (float): Fraction of trajectories held out for validation.

    Raises:
        ValueError: If ``env_type`` is not a recognized dataset type.
    """
    # Fetch the logger by name instead of relying on the module-level global
    # assigned in __main__: getLogger returns the same object the __main__
    # block configures, and this keeps the function usable when imported.
    logger = logging.getLogger("get_D4RL_dataset")

    # Each environment type uses its own dataset class and fixed dims.
    if env_type == "align":
        demo_dataset = Aligning_Dataset(
            load_path,
            action_dim=3,
            obs_dim=20,
            max_len_data=512,
        )
    elif env_type == "avoid":
        demo_dataset = Avoiding_Dataset(
            load_path,
            action_dim=2,
            obs_dim=4,
            max_len_data=200,
        )
    elif env_type == "push":
        demo_dataset = Pushing_Dataset(
            load_path,
            action_dim=2,
            obs_dim=10,
            max_len_data=512,
        )
    elif env_type == "sort":
        # Can config number of boxes to be 2, 4, or 6.
        # TODO: add other numbers of boxes
        demo_dataset = Sorting_Dataset(
            load_path,
            action_dim=2,
            obs_dim=10,
            max_len_data=600,
            num_boxes=2,
        )
    elif env_type == "stack":
        demo_dataset = Stacking_Dataset(
            load_path,
            action_dim=8,
            obs_dim=20,
            max_len_data=1000,
        )
    else:
        raise ValueError("Invalid dataset type.")

    # Extract the length of each trajectory from the padding masks.
    # actions/obs/masks look like (num_traj, max_len, dim) tensors
    # (they support .numpy() / .sum().item()) — TODO confirm dtype/layout.
    traj_lengths = []
    actions = demo_dataset.actions
    obs = demo_dataset.observations
    masks = demo_dataset.masks
    action_dim = actions.shape[2]
    obs_dim = obs.shape[2]
    for ep in range(masks.shape[0]):
        traj_lengths.append(int(masks[ep].sum().item()))
    traj_lengths = np.array(traj_lengths)

    # Split trajectory indices into train and val.
    num_traj = len(traj_lengths)
    num_train = int(num_traj * (1 - val_split))
    train_indices = random.sample(range(num_traj), k=num_train)

    # Per-dimension min/max of obs and action over all valid timesteps.
    # NOTE(review): seeding the running bounds with zeros forces the final
    # bounds to include 0, which biases normalization for dimensions whose
    # data never crosses zero — consider seeding with +/-inf instead (kept
    # as-is because it would change the saved normalization stats).
    obs_min = np.zeros((obs_dim))
    obs_max = np.zeros((obs_dim))
    action_min = np.zeros((action_dim))
    action_max = np.zeros((action_dim))
    for i in tqdm(range(len(traj_lengths))):
        T = traj_lengths[i]
        obs_traj = obs[i, :T].numpy()
        action_traj = actions[i, :T].numpy()
        obs_min = np.min(np.vstack((obs_min, np.min(obs_traj, axis=0))), axis=0)
        obs_max = np.max(np.vstack((obs_max, np.max(obs_traj, axis=0))), axis=0)
        action_min = np.min(
            np.vstack((action_min, np.min(action_traj, axis=0))), axis=0
        )
        action_max = np.max(
            np.vstack((action_max, np.max(action_traj, axis=0))), axis=0
        )

    logger.info("\n========== Basic Info ===========")
    logger.info("total transitions: {}".format(np.sum(traj_lengths)))
    logger.info("total trajectories: {}".format(len(traj_lengths)))
    logger.info(
        f"traj length mean/std: {np.mean(traj_lengths)}, {np.std(traj_lengths)}"
    )
    logger.info(f"traj length min/max: {np.min(traj_lengths)}, {np.max(traj_lengths)}")
    logger.info(f"obs min: {obs_min}")
    logger.info(f"obs max: {obs_max}")
    logger.info(f"action min: {action_min}")
    logger.info(f"action max: {action_max}")

    # Accumulate normalized trajectories into train/val output dicts.
    out_train = {}
    keys = [
        "observations",
        "actions",
        "rewards",
    ]
    total_timesteps = actions.shape[1]
    out_train["observations"] = np.empty((0, total_timesteps, obs_dim))
    out_train["actions"] = np.empty((0, total_timesteps, action_dim))
    out_train["rewards"] = np.empty((0, total_timesteps))
    out_train["traj_length"] = []
    out_val = deepcopy(out_train)
    for i in tqdm(range(len(traj_lengths))):
        if i in train_indices:
            out = out_train
        else:
            out = out_val

        T = traj_lengths[i]
        obs_traj = obs[i].numpy()
        action_traj = actions[i].numpy()

        # Scale to [-1, 1] for both obs and action; the small epsilon in the
        # denominator guards against zero-range (constant) dimensions.
        obs_traj = 2 * (obs_traj - obs_min) / (obs_max - obs_min + 1e-6) - 1
        action_traj = (
            2 * (action_traj - action_min) / (action_max - action_min + 1e-6) - 1
        )

        # Get episode length.
        traj_length = T
        out["traj_length"].append(traj_length)

        # Stack the full padded trajectory onto the output arrays.
        rewards = np.zeros(total_timesteps)  # no reward from d3il dataset
        data_traj = {
            "observations": obs_traj,
            "actions": action_traj,
            "rewards": rewards,
        }
        for key in keys:
            traj = data_traj[key]
            out[key] = np.vstack((out[key], traj[None]))

    # Plot all trajectories and save in a figure.
    def plot(out, name):
        def get_obj_xy_list():
            # Pillar layout — presumably matches the "avoid" environment;
            # TODO confirm this plot is meaningful for other env types.
            mid_pos = 0.5
            offset = 0.075
            first_level_y = -0.1
            level_distance = 0.18
            return [
                [mid_pos, first_level_y],
                [mid_pos - offset, first_level_y + level_distance],
                [mid_pos + offset, first_level_y + level_distance],
                [mid_pos - 2 * offset, first_level_y + 2 * level_distance],
                [mid_pos, first_level_y + 2 * level_distance],
                [mid_pos + 2 * offset, first_level_y + 2 * level_distance],
            ]

        pillar_xys = get_obj_xy_list()
        fig = plt.figure()
        all_trajs = out["observations"]  # num x timestep x obs
        for traj, traj_length in zip(all_trajs, out["traj_length"]):
            # Unnormalize back to raw coordinates for plotting.
            traj = (traj + 1) / 2  # [-1, 1] -> [0, 1]
            traj = traj * (obs_max - obs_min) + obs_min
            # Assumes obs dims 2 and 3 are the x/y positions — TODO confirm.
            plt.plot(
                traj[:traj_length, 2], traj[:traj_length, 3], color=(0.3, 0.3, 0.3)
            )
        plt.axhline(y=0.4, color=np.array([31, 119, 180]) / 255, linestyle="-")
        for xy in pillar_xys:
            circle = plt.Circle(xy, 0.01, color=(0.0, 0.0, 0.0), fill=True)
            plt.gca().add_patch(circle)
        plt.xlabel("X pos")
        plt.ylabel("Y pos")
        plt.xlim([0.2, 0.8])
        plt.ylim([-0.3, 0.5])
        ax = plt.gca()
        ax.set_aspect("equal", adjustable="box")
        ax.set_facecolor("white")
        plt.savefig(os.path.join(save_dir, name))
        plt.close(fig)

    plot(out_train, name="train-trajs.png")
    plot(out_val, name="val-trajs.png")

    # Save splits and normalization stats.
    # NOTE(review): the train file is pickled despite its ".npz" extension
    # (the val file correctly uses ".pkl"); kept as-is for backward
    # compatibility with downstream loaders that expect this filename.
    save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
    save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl")
    with open(save_train_path, "wb") as f:
        pickle.dump(out_train, f)
    with open(save_val_path, "wb") as f:
        pickle.dump(out_val, f)
    normalization_save_path = os.path.join(
        save_dir, save_name_prefix + "normalization.npz"
    )
    np.savez(
        normalization_save_path,
        obs_min=obs_min,
        obs_max=obs_max,
        action_min=action_min,
        action_max=action_max,
    )

    # Debug: log per-split and per-dimension statistics.
    logger.info("\n========== Final ===========")
    logger.info(
        f"Train - Number of episodes and transitions: {len(out_train['traj_length'])}, {np.sum(out_train['traj_length'])}"
    )
    logger.info(
        f"Val - Number of episodes and transitions: {len(out_val['traj_length'])}, {np.sum(out_val['traj_length'])}"
    )
    logger.info(
        f"Train - Mean/Std trajectory length: {np.mean(out_train['traj_length'])}, {np.std(out_train['traj_length'])}"
    )
    logger.info(
        f"Train - Max/Min trajectory length: {np.max(out_train['traj_length'])}, {np.min(out_train['traj_length'])}"
    )
    if val_split > 0:
        logger.info(
            f"Val - Mean/Std trajectory length: {np.mean(out_val['traj_length'])}, {np.std(out_val['traj_length'])}"
        )
        logger.info(
            f"Val - Max/Min trajectory length: {np.max(out_val['traj_length'])}, {np.min(out_val['traj_length'])}"
        )
    for obs_dim_ind in range(obs_dim):
        obs = out_train["observations"][:, :, obs_dim_ind]
        logger.info(
            f"Train - Obs dim {obs_dim_ind+1} mean {np.mean(obs)} std {np.std(obs)} min {np.min(obs)} max {np.max(obs)}"
        )
    for action_dim_ind in range(action_dim):
        action = out_train["actions"][:, :, action_dim_ind]
        logger.info(
            f"Train - Action dim {action_dim_ind+1} mean {np.mean(action)} std {np.std(action)} min {np.min(action)} max {np.max(action)}"
        )
    if val_split > 0:
        for obs_dim_ind in range(obs_dim):
            obs = out_val["observations"][:, :, obs_dim_ind]
            logger.info(
                f"Val - Obs dim {obs_dim_ind+1} mean {np.mean(obs)} std {np.std(obs)} min {np.min(obs)} max {np.max(obs)}"
            )
        for action_dim_ind in range(action_dim):
            action = out_val["actions"][:, :, action_dim_ind]
            logger.info(
                f"Val - Action dim {action_dim_ind+1} mean {np.mean(action)} std {np.std(action)} min {np.min(action)} max {np.max(action)}"
            )
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--load_path", type=str, default=".")
|
|
parser.add_argument("--save_dir", type=str, default=".")
|
|
parser.add_argument("--save_name_prefix", type=str, default="")
|
|
parser.add_argument("--env_type", type=str, default="align")
|
|
parser.add_argument("--val_split", type=float, default="0.2")
|
|
args = parser.parse_args()
|
|
|
|
os.makedirs(args.save_dir, exist_ok=True)
|
|
|
|
import logging
|
|
import datetime
|
|
|
|
os.makedirs(args.save_dir, exist_ok=True)
|
|
log_path = os.path.join(
|
|
args.save_dir,
|
|
args.save_name_prefix
|
|
+ f"_{datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.log",
|
|
)
|
|
logger = logging.getLogger("get_D4RL_dataset")
|
|
logger.setLevel(logging.INFO)
|
|
file_handler = logging.FileHandler(log_path)
|
|
file_handler.setLevel(logging.INFO) # Set the minimum level for this handler
|
|
formatter = logging.Formatter(
|
|
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
)
|
|
file_handler.setFormatter(formatter)
|
|
logger.addHandler(file_handler)
|
|
|
|
make_dataset(
|
|
args.load_path,
|
|
args.save_dir,
|
|
args.save_name_prefix,
|
|
args.env_type,
|
|
args.val_split,
|
|
)
|