From 2ddf63b8f5cac085db705be1c58f81b20c6c4b0e Mon Sep 17 00:00:00 2001 From: allenzren Date: Wed, 11 Sep 2024 21:09:17 -0400 Subject: [PATCH] squash commits --- agent/dataset/sequence.py | 65 +++++--- agent/finetune/train_awr_diffusion_agent.py | 74 +++++---- agent/finetune/train_dipo_diffusion_agent.py | 44 ++++-- agent/finetune/train_dql_diffusion_agent.py | 52 ++++--- agent/finetune/train_idql_diffusion_agent.py | 45 ++++-- agent/finetune/train_ppo_diffusion_agent.py | 68 +++++--- .../finetune/train_ppo_diffusion_img_agent.py | 23 +-- .../train_ppo_exact_diffusion_agent.py | 80 +++++++--- agent/finetune/train_ppo_gaussian_agent.py | 74 ++++++--- .../finetune/train_ppo_gaussian_img_agent.py | 21 +-- agent/finetune/train_qsm_diffusion_agent.py | 40 +++-- agent/finetune/train_rwr_diffusion_agent.py | 56 ++++--- .../avoid_m1/ft_ppo_diffusion_mlp.yaml | 4 +- .../avoid_m1/ft_ppo_gaussian_mlp.yaml | 3 +- .../finetune/avoid_m1/ft_ppo_gmm_mlp.yaml | 3 +- .../avoid_m2/ft_ppo_diffusion_mlp.yaml | 4 +- .../avoid_m2/ft_ppo_gaussian_mlp.yaml | 3 +- .../finetune/avoid_m2/ft_ppo_gmm_mlp.yaml | 3 +- .../avoid_m3/ft_ppo_diffusion_mlp.yaml | 4 +- .../avoid_m3/ft_ppo_gaussian_mlp.yaml | 3 +- .../finetune/avoid_m3/ft_ppo_gmm_mlp.yaml | 3 +- .../pretrain/avoid_m1/pre_diffusion_mlp.yaml | 2 - cfg/d3il/pretrain/avoid_m1/pre_gmm_mlp.yaml | 2 +- .../pretrain/avoid_m2/pre_diffusion_mlp.yaml | 2 - cfg/d3il/pretrain/avoid_m2/pre_gmm_mlp.yaml | 2 +- .../pretrain/avoid_m3/pre_diffusion_mlp.yaml | 2 - cfg/d3il/pretrain/avoid_m3/pre_gmm_mlp.yaml | 2 +- .../lamp_low/ft_ppo_diffusion_mlp.yaml | 4 +- .../lamp_low/ft_ppo_diffusion_unet.yaml | 4 +- .../lamp_low/ft_ppo_gaussian_mlp.yaml | 5 +- .../lamp_med/ft_ppo_diffusion_mlp.yaml | 4 +- .../lamp_med/ft_ppo_diffusion_unet.yaml | 4 +- .../lamp_med/ft_ppo_gaussian_mlp.yaml | 5 +- .../one_leg_low/ft_ppo_diffusion_mlp.yaml | 4 +- .../one_leg_low/ft_ppo_diffusion_unet.yaml | 4 +- .../one_leg_low/ft_ppo_gaussian_mlp.yaml | 5 +- 
.../one_leg_med/ft_ppo_diffusion_mlp.yaml | 4 +- .../one_leg_med/ft_ppo_diffusion_unet.yaml | 4 +- .../one_leg_med/ft_ppo_gaussian_mlp.yaml | 5 +- .../round_table_low/ft_ppo_diffusion_mlp.yaml | 4 +- .../ft_ppo_diffusion_unet.yaml | 4 +- .../round_table_low/ft_ppo_gaussian_mlp.yaml | 3 +- .../round_table_med/ft_ppo_diffusion_mlp.yaml | 4 +- .../ft_ppo_diffusion_unet.yaml | 4 +- .../round_table_med/ft_ppo_gaussian_mlp.yaml | 5 +- .../pretrain/lamp_low/pre_diffusion_mlp.yaml | 2 - .../pretrain/lamp_low/pre_diffusion_unet.yaml | 2 - .../pretrain/lamp_med/pre_diffusion_mlp.yaml | 2 - .../pretrain/lamp_med/pre_diffusion_unet.yaml | 2 - .../one_leg_low/pre_diffusion_mlp.yaml | 2 - .../one_leg_low/pre_diffusion_unet.yaml | 2 - .../one_leg_med/pre_diffusion_mlp.yaml | 2 - .../one_leg_med/pre_diffusion_unet.yaml | 2 - .../round_table_low/pre_diffusion_mlp.yaml | 2 - .../round_table_low/pre_diffusion_unet.yaml | 2 - .../round_table_med/pre_diffusion_mlp.yaml | 2 - .../round_table_med/pre_diffusion_unet.yaml | 2 - .../halfcheetah-v2/ft_awr_diffusion_mlp.yaml | 5 +- .../halfcheetah-v2/ft_dipo_diffusion_mlp.yaml | 5 +- .../halfcheetah-v2/ft_dql_diffusion_mlp.yaml | 5 +- .../halfcheetah-v2/ft_idql_diffusion_mlp.yaml | 7 +- .../halfcheetah-v2/ft_ppo_diffusion_mlp.yaml | 6 +- .../ft_ppo_exact_diffusion_mlp.yaml | 6 +- .../halfcheetah-v2/ft_qsm_diffusion_mlp.yaml | 5 +- .../halfcheetah-v2/ft_rwr_diffusion_mlp.yaml | 3 +- .../halfcheetah-v2/ppo_diffusion_mlp.yaml | 6 +- .../halfcheetah-v2/ppo_gaussian_mlp.yaml | 5 +- .../hopper-v2/ft_awr_diffusion_mlp.yaml | 5 +- .../hopper-v2/ft_dipo_diffusion_mlp.yaml | 5 +- .../hopper-v2/ft_dql_diffusion_mlp.yaml | 5 +- .../hopper-v2/ft_idql_diffusion_mlp.yaml | 7 +- .../hopper-v2/ft_ppo_diffusion_mlp.yaml | 6 +- .../hopper-v2/ft_ppo_exact_diffusion_mlp.yaml | 6 +- .../hopper-v2/ft_qsm_diffusion_mlp.yaml | 5 +- .../hopper-v2/ft_rwr_diffusion_mlp.yaml | 3 +- .../finetune/hopper-v2/ppo_diffusion_mlp.yaml | 6 +- 
.../finetune/hopper-v2/ppo_gaussian_mlp.yaml | 5 +- .../walker2d-v2/ft_awr_diffusion_mlp.yaml | 5 +- .../walker2d-v2/ft_dipo_diffusion_mlp.yaml | 5 +- .../walker2d-v2/ft_dql_diffusion_mlp.yaml | 5 +- .../walker2d-v2/ft_idql_diffusion_mlp.yaml | 7 +- .../walker2d-v2/ft_ppo_diffusion_mlp.yaml | 6 +- .../walker2d-v2/ft_qsm_diffusion_mlp.yaml | 5 +- .../walker2d-v2/ft_rwr_diffusion_mlp.yaml | 3 +- .../walker2d-v2/ppo_diffusion_mlp.yaml | 6 +- .../walker2d-v2/ppo_gaussian_mlp.yaml | 5 +- .../pre_diffusion_mlp.yaml | 4 +- .../hopper-medium-v2/pre_diffusion_mlp.yaml | 4 +- .../walker2d-medium-v2/pre_diffusion_mlp.yaml | 4 +- .../finetune/can/ft_awr_diffusion_mlp.yaml | 3 +- .../finetune/can/ft_dipo_diffusion_mlp.yaml | 3 +- .../finetune/can/ft_dql_diffusion_mlp.yaml | 3 +- .../finetune/can/ft_idql_diffusion_mlp.yaml | 5 +- .../finetune/can/ft_ppo_diffusion_mlp.yaml | 4 +- .../can/ft_ppo_diffusion_mlp_img.yaml | 9 +- .../finetune/can/ft_ppo_diffusion_unet.yaml | 4 +- .../can/ft_ppo_exact_diffusion_mlp.yaml | 4 +- .../finetune/can/ft_ppo_gaussian_mlp.yaml | 3 +- .../finetune/can/ft_ppo_gaussian_mlp_img.yaml | 8 +- .../can/ft_ppo_gaussian_transformer.yaml | 5 +- .../finetune/can/ft_ppo_gmm_mlp.yaml | 5 +- .../finetune/can/ft_ppo_gmm_transformer.yaml | 5 +- .../finetune/can/ft_qsm_diffusion_mlp.yaml | 3 +- .../finetune/can/ft_rwr_diffusion_mlp.yaml | 1 - .../finetune/lift/ft_awr_diffusion_mlp.yaml | 3 +- .../finetune/lift/ft_dipo_diffusion_mlp.yaml | 3 +- .../finetune/lift/ft_dql_diffusion_mlp.yaml | 3 +- .../finetune/lift/ft_idql_diffusion_mlp.yaml | 5 +- .../finetune/lift/ft_ppo_diffusion_mlp.yaml | 4 +- .../lift/ft_ppo_diffusion_mlp_img.yaml | 9 +- .../finetune/lift/ft_ppo_diffusion_unet.yaml | 4 +- .../finetune/lift/ft_ppo_gaussian_mlp.yaml | 3 +- .../lift/ft_ppo_gaussian_mlp_img.yaml | 8 +- .../lift/ft_ppo_gaussian_transformer.yaml | 3 +- .../finetune/lift/ft_ppo_gmm_mlp.yaml | 3 +- .../finetune/lift/ft_ppo_gmm_transformer.yaml | 3 +- 
.../finetune/lift/ft_qsm_diffusion_mlp.yaml | 3 +- .../finetune/lift/ft_rwr_diffusion_mlp.yaml | 1 - .../finetune/square/ft_awr_diffusion_mlp.yaml | 3 +- .../square/ft_dipo_diffusion_mlp.yaml | 3 +- .../finetune/square/ft_dql_diffusion_mlp.yaml | 3 +- .../square/ft_idql_diffusion_mlp.yaml | 5 +- .../finetune/square/ft_ppo_diffusion_mlp.yaml | 4 +- .../square/ft_ppo_diffusion_mlp_img.yaml | 9 +- .../square/ft_ppo_diffusion_unet.yaml | 4 +- .../finetune/square/ft_ppo_gaussian_mlp.yaml | 3 +- .../square/ft_ppo_gaussian_mlp_img.yaml | 8 +- .../square/ft_ppo_gaussian_transformer.yaml | 3 +- .../finetune/square/ft_ppo_gmm_mlp.yaml | 3 +- .../square/ft_ppo_gmm_transformer.yaml | 3 +- .../finetune/square/ft_qsm_diffusion_mlp.yaml | 3 +- .../finetune/square/ft_rwr_diffusion_mlp.yaml | 1 - .../transport/ft_awr_diffusion_mlp.yaml | 3 +- .../transport/ft_dipo_diffusion_mlp.yaml | 3 +- .../transport/ft_dql_diffusion_mlp.yaml | 3 +- .../transport/ft_idql_diffusion_mlp.yaml | 5 +- .../transport/ft_ppo_diffusion_mlp.yaml | 4 +- .../transport/ft_ppo_diffusion_mlp_img.yaml | 9 +- .../transport/ft_ppo_diffusion_unet.yaml | 4 +- .../transport/ft_ppo_gaussian_mlp.yaml | 3 +- .../transport/ft_ppo_gaussian_mlp_img.yaml | 8 +- .../ft_ppo_gaussian_transformer.yaml | 3 +- .../finetune/transport/ft_ppo_gmm_mlp.yaml | 3 +- .../transport/ft_ppo_gmm_transformer.yaml | 3 +- .../transport/ft_qsm_diffusion_mlp.yaml | 3 +- .../transport/ft_rwr_diffusion_mlp.yaml | 1 - .../pretrain/can/pre_diffusion_mlp.yaml | 4 +- .../pretrain/can/pre_diffusion_mlp_img.yaml | 8 +- .../pretrain/can/pre_diffusion_unet.yaml | 2 - .../pretrain/can/pre_gaussian_mlp_img.yaml | 4 + .../pretrain/lift/pre_diffusion_mlp.yaml | 2 - .../pretrain/lift/pre_diffusion_mlp_img.yaml | 8 +- .../pretrain/lift/pre_diffusion_unet.yaml | 4 +- .../pretrain/lift/pre_gaussian_mlp_img.yaml | 4 + .../pretrain/square/pre_diffusion_mlp.yaml | 2 - .../square/pre_diffusion_mlp_img.yaml | 8 +- .../pretrain/square/pre_diffusion_unet.yaml | 2 - 
.../pretrain/square/pre_gaussian_mlp_img.yaml | 4 + .../pretrain/transport/pre_diffusion_mlp.yaml | 2 - .../transport/pre_diffusion_mlp_img.yaml | 8 +- .../transport/pre_diffusion_unet.yaml | 2 - .../transport/pre_gaussian_mlp_img.yaml | 6 +- env/gym_utils/__init__.py | 2 +- env/gym_utils/wrapper/d3il_lowdim.py | 30 ++-- env/gym_utils/wrapper/furniture.py | 6 +- .../wrapper/mujoco_locomotion_lowdim.py | 18 ++- env/gym_utils/wrapper/multi_step.py | 22 +-- env/gym_utils/wrapper/robomimic_image.py | 4 +- env/gym_utils/wrapper/robomimic_lowdim.py | 41 +++-- model/common/critic.py | 97 ++++++++---- model/common/gaussian.py | 20 +-- model/common/gmm.py | 22 ++- model/common/mlp_gaussian.py | 52 ++++--- model/common/mlp_gmm.py | 16 +- model/common/modules.py | 18 +++ model/common/transformer.py | 36 +++-- model/common/vit.py | 48 +++--- model/diffusion/diffusion.py | 67 +------- model/diffusion/diffusion_dipo.py | 21 ++- model/diffusion/diffusion_dql.py | 32 ++-- model/diffusion/diffusion_idql.py | 91 +++++------ model/diffusion/diffusion_ppo.py | 19 ++- model/diffusion/diffusion_ppo_exact.py | 21 ++- model/diffusion/diffusion_qsm.py | 27 ++-- model/diffusion/diffusion_rwr.py | 20 +-- model/diffusion/diffusion_vpg.py | 147 ++++++++---------- model/diffusion/eta.py | 46 +++--- model/diffusion/exact_likelihood.py | 13 +- model/diffusion/mlp_diffusion.py | 76 +++++---- model/diffusion/sampling.py | 6 - model/diffusion/unet.py | 19 ++- model/rl/gaussian_ppo.py | 14 +- model/rl/gaussian_rwr.py | 15 +- model/rl/gaussian_vpg.py | 53 +++---- model/rl/gmm_ppo.py | 14 +- model/rl/gmm_vpg.py | 36 ++--- script/dataset/filter_d3il_avoid_data.py | 2 +- script/dataset/get_d4rl_dataset.py | 2 +- script/dataset/process_d3il_dataset.py | 2 +- script/dataset/process_robomimic_dataset.py | 2 +- 200 files changed, 1240 insertions(+), 1186 deletions(-) diff --git a/agent/dataset/sequence.py b/agent/dataset/sequence.py index 3a95fd3..f96200d 100644 --- a/agent/dataset/sequence.py +++ 
b/agent/dataset/sequence.py @@ -16,19 +16,21 @@ import random log = logging.getLogger(__name__) -Batch = namedtuple("Batch", "trajectories conditions") +Batch = namedtuple("Batch", "actions conditions") class StitchedSequenceDataset(torch.utils.data.Dataset): """ - Dataset to efficiently load and sample trajectories. Stitches episodes together in the time dimension to avoid excessive zero padding. Episode ID's are used to index unique trajectories. + Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file. - Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is traj length of episode e and D is - (tuple of) dimension of observation, action, images, etc. + Use the first max_n_episodes episodes (instead of random sampling) Example: states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------] - Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------] + Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------] + + Each sample is a namedtuple of (1) chunked actions and (2) a list (obs timesteps) of dictionary with keys states and images. + """ def __init__( @@ -36,23 +38,30 @@ class StitchedSequenceDataset(torch.utils.data.Dataset): dataset_path, horizon_steps=64, cond_steps=1, + img_cond_steps=1, max_n_episodes=10000, use_img=False, device="cuda:0", ): + assert ( + img_cond_steps <= cond_steps + ), "consider using more cond_steps than img_cond_steps" self.horizon_steps = horizon_steps - self.cond_steps = cond_steps + self.cond_steps = cond_steps # states (proprio, etc.) 
+ self.img_cond_steps = img_cond_steps self.device = device self.use_img = use_img # Load dataset to device specified if dataset_path.endswith(".npz"): dataset = np.load(dataset_path, allow_pickle=False) # only np arrays - else: + elif dataset_path.endswith(".pkl"): with open(dataset_path, "rb") as f: dataset = pickle.load(f) - traj_lengths = dataset["traj_lengths"] # 1-D array - total_num_steps = np.sum(traj_lengths[:max_n_episodes]) + else: + raise ValueError(f"Unsupported file format: {dataset_path}") + traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array + total_num_steps = np.sum(traj_lengths) # Set up indices for sampling self.indices = self.make_indices(traj_lengths, horizon_steps) @@ -75,35 +84,51 @@ class StitchedSequenceDataset(torch.utils.data.Dataset): log.info(f"Images shape/type: {self.images.shape, self.images.dtype}") def __getitem__(self, idx): - start = self.indices[idx] + """ + repeat states/images if using history observation at the beginning of the episode + """ + start, num_before_start = self.indices[idx] end = start + self.horizon_steps - states = self.states[start:end] + states = self.states[(start - num_before_start) : end] actions = self.actions[start:end] + states = torch.stack( + [ + states[min(num_before_start - t, 0)] + for t in reversed(range(self.cond_steps)) + ] + ) # more recent is at the end + conditions = {"state": states} if self.use_img: - images = self.images[start:end] - conditions = { - 1 - self.cond_steps: {"state": states[0], "rgb": images[0]} - } # TODO: allow obs history, -1, -2, ... 
- else: - conditions = {1 - self.cond_steps: states[0]} + images = self.images[(start - num_before_start) : end] + images = torch.stack( + [ + images[min(num_before_start - t, 0)] + for t in reversed(range(self.img_cond_steps)) + ] + ) + conditions["rgb"] = images batch = Batch(actions, conditions) return batch def make_indices(self, traj_lengths, horizon_steps): """ makes indices for sampling from dataset; - each index maps to a datapoint + each index maps to a datapoint, also save the number of steps before it within the same trajectory """ indices = [] cur_traj_index = 0 for traj_length in traj_lengths: max_start = cur_traj_index + traj_length - horizon_steps + 1 - indices += list(range(cur_traj_index, max_start)) + indices += [ + (i, i - cur_traj_index) for i in range(cur_traj_index, max_start) + ] cur_traj_index += traj_length return indices def set_train_val_split(self, train_split): - """Not doing validation right now""" + """ + Not doing validation right now + """ num_train = int(len(self.indices) * train_split) train_indices = random.sample(self.indices, num_train) val_indices = [i for i in range(len(self.indices)) if i not in train_indices] diff --git a/agent/finetune/train_awr_diffusion_agent.py b/agent/finetune/train_awr_diffusion_agent.py index 2c96556..4ec3cf1 100644 --- a/agent/finetune/train_awr_diffusion_agent.py +++ b/agent/finetune/train_awr_diffusion_agent.py @@ -3,6 +3,8 @@ Advantage-weighted regression (AWR) for diffusion policy. Advantage = discounted-reward-to-go - V(s) +Do not support pixel input right now. 
+ """ import os @@ -131,6 +133,7 @@ class TrainAWRDiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -145,9 +148,10 @@ class TrainAWRDiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -155,7 +159,6 @@ class TrainAWRDiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env @@ -165,16 +168,19 @@ class TrainAWRDiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode, ) .cpu() .numpy() - ) # n_env x horizon x act + ) action_venv = samples[:, : self.act_steps] # Apply multi-step action @@ -184,7 +190,7 @@ class TrainAWRDiffusionAgent(TrainAgent): reward_trajs = np.vstack((reward_trajs, reward_venv[None])) # add to buffer - obs_buffer.append(prev_obs_venv) + obs_buffer.append(prev_obs_venv["state"]) action_buffer.append(action_venv) reward_buffer.append(reward_venv * self.scale_reward_factor) done_buffer.append(done_venv) @@ -230,59 +236,46 @@ class 
TrainAWRDiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: - - obs_trajs = np.array(deepcopy(obs_buffer)) + obs_trajs = np.array(deepcopy(obs_buffer)) # assume only state reward_trajs = np.array(deepcopy(reward_buffer)) dones_trajs = np.array(deepcopy(done_buffer)) - obs_t = einops.rearrange( torch.from_numpy(obs_trajs).float().to(self.device), "s e h d -> (s e) h d", ) - values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy()) - values_trajs = values_t.reshape(-1, self.n_envs) + values_trajs = np.array( + self.model.critic({"state": obs_t}).detach().cpu().numpy() + ).reshape(-1, self.n_envs) td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs) + td_t = torch.from_numpy(td_trajs.flatten()).float().to(self.device) - # flatten - obs_trajs = einops.rearrange( - obs_trajs, - "s e h d -> (s e) h d", - ) - td_trajs = einops.rearrange( - td_trajs, - "s e -> (s e)", - ) - - # Update policy and critic + # Update critic num_batch = int( self.n_steps * self.n_envs / self.batch_size * self.replay_ratio ) for _ in range(num_batch // self.critic_update_ratio): - - # Sample batch inds = np.random.choice(len(obs_trajs), self.batch_size) - obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device) - td_b = torch.from_numpy(td_trajs[inds]).float().to(self.device) - - # Update critic - loss_critic = self.model.loss_critic(obs_b, td_b) + loss_critic = self.model.loss_critic( + {"state": obs_t[inds]}, td_t[inds] + ) self.critic_optimizer.zero_grad() loss_critic.backward() self.critic_optimizer.step() + # Update policy - use a new copy of data obs_trajs = np.array(deepcopy(obs_buffer)) samples_trajs = np.array(deepcopy(action_buffer)) reward_trajs = np.array(deepcopy(reward_buffer)) dones_trajs = np.array(deepcopy(done_buffer)) - obs_t = einops.rearrange( torch.from_numpy(obs_trajs).float().to(self.device), "s e h d -> (s e) h d", ) - values_t 
= np.array(self.model.critic(obs_t).detach().cpu().numpy()) - values_trajs = values_t.reshape(-1, self.n_envs) + values_trajs = np.array( + self.model.critic({"state": obs_t}).detach().cpu().numpy() + ).reshape(-1, self.n_envs) td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs) advantages_trajs = td_trajs - values_trajs @@ -304,7 +297,11 @@ class TrainAWRDiffusionAgent(TrainAgent): # Sample batch inds = np.random.choice(len(obs_trajs), self.batch_size) - obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device) + obs_b = { + "state": torch.from_numpy(obs_trajs[inds]) + .float() + .to(self.device) + } actions_b = ( torch.from_numpy(samples_trajs[inds]).float().to(self.device) ) @@ -347,6 +344,7 @@ class TrainAWRDiffusionAgent(TrainAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -367,7 +365,7 @@ class TrainAWRDiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -383,7 +381,7 @@ class TrainAWRDiffusionAgent(TrainAgent): run_results[-1]["loss"] = loss run_results[-1]["loss_critic"] = loss_critic run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_dipo_diffusion_agent.py b/agent/finetune/train_dipo_diffusion_agent.py index 875f160..1726dff 100644 --- a/agent/finetune/train_dipo_diffusion_agent.py +++ b/agent/finetune/train_dipo_diffusion_agent.py @@ -4,6 +4,9 @@ Model-free online RL with DIffusion POlicy (DIPO) Applies action gradient 
to perturb actions towards maximizer of Q-function. a_t <- a_t + \eta * \grad_a Q(s, a) + +Do not support pixel input right now. + """ import os @@ -90,6 +93,7 @@ class TrainDIPODiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -104,9 +108,10 @@ class TrainDIPODiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -114,7 +119,6 @@ class TrainDIPODiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env @@ -124,11 +128,14 @@ class TrainDIPODiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode, ) .cpu() @@ -144,8 +151,8 @@ class TrainDIPODiffusionAgent(TrainAgent): # add to buffer for i in range(self.n_envs): - obs_buffer.append(prev_obs_venv[i]) - next_obs_buffer.append(obs_venv[i]) + obs_buffer.append(prev_obs_venv["state"][i]) + next_obs_buffer.append(obs_venv["state"][i]) action_buffer.append(action_venv[i]) reward_buffer.append(reward_venv[i] * 
self.scale_reward_factor) done_buffer.append(done_venv[i]) @@ -191,8 +198,8 @@ class TrainDIPODiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") + # Update models if not eval_mode: - num_batch = self.replay_ratio # Critic learning @@ -231,7 +238,12 @@ class TrainDIPODiffusionAgent(TrainAgent): # Update critic loss_critic = self.model.loss_critic( - obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma + {"state": obs_b}, + {"state": next_obs_b}, + actions_b, + rewards_b, + dones_b, + self.gamma, ) self.critic_optimizer.zero_grad() loss_critic.backward() @@ -239,7 +251,6 @@ class TrainDIPODiffusionAgent(TrainAgent): # Actor learning for _ in range(num_batch): - # Sample batch inds = np.random.choice(len(obs_buffer), self.batch_size) obs_b = ( @@ -265,7 +276,9 @@ class TrainDIPODiffusionAgent(TrainAgent): ) for _ in range(self.action_gradient_steps): actions_flat.requires_grad_(True) - q_values_1, q_values_2 = self.model.critic(obs_b, actions_flat) + q_values_1, q_values_2 = self.model.critic( + {"state": obs_b}, actions_flat + ) q_values = torch.min(q_values_1, q_values_2) action_opt_loss = -q_values.sum() @@ -291,7 +304,7 @@ class TrainDIPODiffusionAgent(TrainAgent): ) # Update policy with collected trajectories - loss = self.model.loss(guided_action.detach(), {0: obs_b}) + loss = self.model.loss(guided_action.detach(), {"state": obs_b}) self.actor_optimizer.zero_grad() loss.backward() if self.itr >= self.n_critic_warmup_itr: @@ -316,6 +329,7 @@ class TrainDIPODiffusionAgent(TrainAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -336,7 +350,7 @@ class TrainDIPODiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} 
|t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -352,7 +366,7 @@ class TrainDIPODiffusionAgent(TrainAgent): run_results[-1]["loss"] = loss run_results[-1]["loss_critic"] = loss_critic run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_dql_diffusion_agent.py b/agent/finetune/train_dql_diffusion_agent.py index 1a90394..2e6e4e0 100644 --- a/agent/finetune/train_dql_diffusion_agent.py +++ b/agent/finetune/train_dql_diffusion_agent.py @@ -5,6 +5,9 @@ Learns a critic Q-function and backprops the expected Q-value to train the actor pi = argmin L_d(\theta) - \alpha * E[Q(s, a)] L_d is demonstration loss for regularization + +Do not support pixel input right now. + """ import os @@ -38,7 +41,6 @@ class TrainDQLDiffusionAgent(TrainAgent): lr=cfg.train.actor_lr, weight_decay=cfg.train.actor_weight_decay, ) - # use cosine scheduler with linear warmup self.actor_lr_scheduler = CosineAnnealingWarmupRestarts( self.actor_optimizer, first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps, @@ -88,6 +90,7 @@ class TrainDQLDiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -102,9 +105,10 @@ class TrainDQLDiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 
1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -112,7 +116,6 @@ class TrainDQLDiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env @@ -122,11 +125,14 @@ class TrainDQLDiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode, ) .cpu() @@ -142,8 +148,8 @@ class TrainDQLDiffusionAgent(TrainAgent): # add to buffer for i in range(self.n_envs): - obs_buffer.append(prev_obs_venv[i]) - next_obs_buffer.append(obs_venv[i]) + obs_buffer.append(prev_obs_venv["state"][i]) + next_obs_buffer.append(obs_venv["state"][i]) action_buffer.append(action_venv[i]) reward_buffer.append(reward_venv[i] * self.scale_reward_factor) done_buffer.append(done_venv[i]) @@ -189,8 +195,8 @@ class TrainDQLDiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") + # Update models if not eval_mode: - num_batch = self.replay_ratio # Critic learning @@ -229,7 +235,12 @@ class TrainDQLDiffusionAgent(TrainAgent): # Update critic loss_critic = self.model.loss_critic( - obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma + {"state": obs_b}, + {"state": next_obs_b}, + actions_b, + rewards_b, + dones_b, + self.gamma, ) self.critic_optimizer.zero_grad() loss_critic.backward() @@ -237,19 +248,21 @@ class TrainDQLDiffusionAgent(TrainAgent): # get the new action and q values samples = self.model.forward_train( - cond=obs_b.to(self.device), + cond={"state": obs_b}, deterministic=eval_mode, ) - output_venv = samples # n_env x 
horizon x act - action_venv = output_venv[:, : self.act_steps, : self.action_dim] - actions_flat_b = action_venv.reshape(action_venv.shape[0], -1) - q_values_b = self.model.critic(obs_b, actions_flat_b) + action_venv = samples[:, : self.act_steps] # n_env x horizon x act + q_values_b = self.model.critic({"state": obs_b}, action_venv) q1_new_action, q2_new_action = q_values_b # Update policy with collected trajectories self.actor_optimizer.zero_grad() actor_loss = self.model.loss_actor( - obs_b, actions_b, q1_new_action, q2_new_action, self.eta + {"state": obs_b}, + actions_b, + q1_new_action, + q2_new_action, + self.eta, ) actor_loss.backward() if self.itr >= self.n_critic_warmup_itr: @@ -275,6 +288,7 @@ class TrainDQLDiffusionAgent(TrainAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -295,7 +309,7 @@ class TrainDQLDiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -311,7 +325,7 @@ class TrainDQLDiffusionAgent(TrainAgent): run_results[-1]["loss"] = loss run_results[-1]["loss_critic"] = loss_critic run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_idql_diffusion_agent.py b/agent/finetune/train_idql_diffusion_agent.py index 5b71169..e6dd168 100644 --- a/agent/finetune/train_idql_diffusion_agent.py +++ b/agent/finetune/train_idql_diffusion_agent.py @@ -1,6 +1,8 @@ """ Implicit diffusion Q-learning (IDQL) trainer for diffusion policy. 
+Do not support pixel input right now. + """ import os @@ -11,7 +13,6 @@ import torch import logging import wandb from copy import deepcopy -import random log = logging.getLogger(__name__) from util.timer import Timer @@ -98,8 +99,8 @@ class TrainIDQLDiffusionAgent(TrainAgent): # make a FIFO replay buffer for obs, action, and reward obs_buffer = deque(maxlen=self.buffer_size) - action_buffer = deque(maxlen=self.buffer_size) next_obs_buffer = deque(maxlen=self.buffer_size) + action_buffer = deque(maxlen=self.buffer_size) reward_buffer = deque(maxlen=self.buffer_size) done_buffer = deque(maxlen=self.buffer_size) first_buffer = deque(maxlen=self.buffer_size) @@ -107,6 +108,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -121,9 +123,10 @@ class TrainIDQLDiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -131,7 +134,6 @@ class TrainIDQLDiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env @@ -141,11 +143,14 @@ class TrainIDQLDiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + 
.to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode and self.eval_deterministic, num_sample=self.num_sample, use_expectile_exploration=self.use_expectile_exploration, @@ -162,9 +167,9 @@ class TrainIDQLDiffusionAgent(TrainAgent): reward_trajs = np.vstack((reward_trajs, reward_venv[None])) # add to buffer - obs_buffer.append(prev_obs_venv) + obs_buffer.append(prev_obs_venv["state"]) + next_obs_buffer.append(obs_venv["state"]) action_buffer.append(action_venv) - next_obs_buffer.append(obs_venv) reward_buffer.append(reward_venv * self.scale_reward_factor) done_buffer.append(done_venv) first_buffer.append(firsts_trajs[step]) @@ -209,7 +214,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: obs_trajs = np.array(deepcopy(obs_buffer)) @@ -257,14 +262,21 @@ class TrainIDQLDiffusionAgent(TrainAgent): done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device) # update critic value function - critic_loss_v = self.model.loss_critic_v(obs_b, actions_b) + critic_loss_v = self.model.loss_critic_v( + {"state": obs_b}, actions_b + ) self.critic_v_optimizer.zero_grad() critic_loss_v.backward() self.critic_v_optimizer.step() # update critic q function critic_loss_q = self.model.loss_critic_q( - obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma + {"state": obs_b}, + {"state": next_obs_b}, + actions_b, + reward_b, + done_b, + self.gamma, ) self.critic_q_optimizer.zero_grad() critic_loss_q.backward() @@ -278,7 +290,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): # Update policy with collected trajectories - no weighting loss = self.model.loss( actions_b, - obs_b, + {"state": obs_b}, ) self.actor_optimizer.zero_grad() loss.backward() @@ -305,6 +317,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): } ) if self.itr % self.log_freq == 0: + time = 
timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -325,7 +338,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -341,7 +354,7 @@ class TrainIDQLDiffusionAgent(TrainAgent): run_results[-1]["loss"] = loss run_results[-1]["loss_critic"] = loss_critic run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_ppo_diffusion_agent.py b/agent/finetune/train_ppo_diffusion_agent.py index 380c723..3677b91 100644 --- a/agent/finetune/train_ppo_diffusion_agent.py +++ b/agent/finetune/train_ppo_diffusion_agent.py @@ -10,6 +10,7 @@ import numpy as np import torch import logging import wandb +import math log = logging.getLogger(__name__) from util.timer import Timer @@ -78,7 +79,9 @@ class TrainPPODiffusionAgent(TrainPPOAgent): ) # Holder - obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + obs_trajs = { + "state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + } chains_trajs = np.empty( ( 0, @@ -91,8 +94,8 @@ class TrainPPODiffusionAgent(TrainPPOAgent): reward_trajs = np.empty((0, self.n_envs)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.vstack( - (obs_full_trajs, prev_obs_venv[None].squeeze(2)) - ) # remove cond_step dim + (obs_full_trajs, prev_obs_venv["state"][:, -1][None]) + ) # save current obs # Collect a set of trajectories from env for step in range(self.n_steps): @@ -101,8 +104,13 @@ class TrainPPODiffusionAgent(TrainPPOAgent): # Select 
action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = self.model( - cond=torch.from_numpy(prev_obs_venv).float().to(self.device), + cond=cond, deterministic=eval_mode, return_chain=True, ) @@ -118,14 +126,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent): obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) - if self.save_full_observations: - obs_full_venv = np.vstack( - [info["full_obs"][None] for info in info_venv] - ) # n_envs x n_act_steps x obs_dim + if self.save_full_observations: # state-only + obs_full_venv = np.array( + [info["full_obs"]["state"] for info in info_venv] + ) # n_envs x act_steps x obs_dim obs_full_trajs = np.vstack( (obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) ) - obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None])) + obs_trajs["state"] = np.vstack( + (obs_trajs["state"], prev_obs_venv["state"][None]) + ) chains_trajs = np.vstack((chains_trajs, chains_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None])) @@ -177,12 +187,22 @@ class TrainPPODiffusionAgent(TrainPPOAgent): # Update models if not eval_mode: with torch.no_grad(): - # Calculate value and logprobs - split into batches to prevent out of memory - obs_t = einops.rearrange( - torch.from_numpy(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", + obs_trajs["state"] = ( + torch.from_numpy(obs_trajs["state"]).float().to(self.device) ) - obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0) + + # Calculate value and logprobs - split into batches to prevent out of memory + num_split = math.ceil( + self.n_envs * self.n_steps / self.logprob_batch_size + ) + obs_ts = [{} for _ in range(num_split)] + obs_k = einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0) + for i, obs_t in enumerate(obs_ts_k): + obs_ts[i]["state"] = obs_t values_trajs = np.empty((0, self.n_envs)) for obs in obs_ts: values = self.model.critic(obs).cpu().numpy().flatten() @@ -219,7 +239,11 @@ class TrainPPODiffusionAgent(TrainPPOAgent): reward_trajs = reward_trajs_transpose.T # bootstrap value with GAE if not done - apply reward scaling with constant if specified - obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) + obs_venv_ts = { + "state": torch.from_numpy(obs_venv["state"]) + .float() + .to(self.device) + } with torch.no_grad(): next_value = ( self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() @@ -250,10 +274,12 @@ class TrainPPODiffusionAgent(TrainPPOAgent): returns_trajs = advantages_trajs + values_trajs # k for environment step - obs_k = einops.rearrange( - torch.tensor(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", - ) + obs_k = { + "state": einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + } chains_k = einops.rearrange( torch.tensor(chains_trajs).float().to(self.device), "s e t h d -> (s e) t h d", @@ -283,7 +309,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = obs_k[inds_b] + obs_b = {"state": obs_k["state"][inds_b]} chains_b = chains_k[inds_b] returns_b = returns_k[inds_b] values_b = values_k[inds_b] @@ -351,7 +377,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent): np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y ) - # Plot state trajectories in D3IL + # Plot state trajectories (only in D3IL) if ( self.itr % self.render_freq == 0 and self.n_render > 0 diff --git a/agent/finetune/train_ppo_diffusion_img_agent.py b/agent/finetune/train_ppo_diffusion_img_agent.py index 08ef33f..75bb101 100644 --- a/agent/finetune/train_ppo_diffusion_img_agent.py +++ b/agent/finetune/train_ppo_diffusion_img_agent.py @@ -30,7 +30,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): # Set obs dim - we will save the different obs in batch in a dict shape_meta = cfg.shape_meta - self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()} + self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs} # Gradient accumulation to deal with large GPU RAM usage self.grad_accumulate = cfg.train.grad_accumulate @@ -95,7 +95,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): key: torch.from_numpy(prev_obs_venv[key]) .float() .to(self.device) - for key in self.obs_dims.keys() + for key in self.obs_dims } # batch each type of obs and put into dict samples = self.model( cond=cond, @@ -114,7 +114,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) - for k in obs_trajs.keys(): + for k in obs_trajs: obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None])) chains_trajs = np.vstack((chains_trajs, 
chains_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) @@ -159,7 +159,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: with torch.no_grad(): # apply image randomization @@ -187,7 +187,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): self.n_envs * self.n_steps / self.logprob_batch_size ) obs_ts = [{} for _ in range(num_split)] - for k in obs_trajs.keys(): + for k in obs_trajs: obs_k = einops.rearrange( obs_trajs[k], "s e ... -> (s e) ...", @@ -238,7 +238,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): # bootstrap value with GAE if not done - apply reward scaling with constant if specified obs_venv_ts = { key: torch.from_numpy(obs_venv[key]).float().to(self.device) - for key in self.obs_dims.keys() + for key in self.obs_dims } with torch.no_grad(): next_value = ( @@ -278,7 +278,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): obs_trajs[k], "s e ... 
-> (s e) ...", ) - for k in obs_trajs.keys() + for k in obs_trajs } chains_k = einops.rearrange( torch.tensor(chains_trajs).float().to(self.device), @@ -309,7 +309,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()} + obs_b = {k: obs_k[k][inds_b] for k in obs_k} chains_b = chains_k[inds_b] returns_b = returns_k[inds_b] values_b = values_k[inds_b] @@ -387,7 +387,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y ) - # Update lr + # Update lr, min_sampling_std if self.itr >= self.n_critic_warmup_itr: self.actor_lr_scheduler.step() if self.learn_eta: @@ -407,6 +407,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -427,7 +428,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -462,7 +463,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent): run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["explained_variance"] = explained_var run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) 
self.itr += 1 diff --git a/agent/finetune/train_ppo_exact_diffusion_agent.py b/agent/finetune/train_ppo_exact_diffusion_agent.py index a70541d..e71c00f 100644 --- a/agent/finetune/train_ppo_exact_diffusion_agent.py +++ b/agent/finetune/train_ppo_exact_diffusion_agent.py @@ -1,6 +1,8 @@ """ Use diffusion exact likelihood for policy gradient. +Do not support pixel input yet. + """ import os @@ -10,6 +12,7 @@ import numpy as np import torch import logging import wandb +import math log = logging.getLogger(__name__) from util.timer import Timer @@ -43,6 +46,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train + eval_mode = False self.model.eval() if eval_mode else self.model.train() last_itr_eval = eval_mode @@ -58,7 +62,9 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): ) # Holder - obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + obs_trajs = { + "state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + } samples_trajs = np.empty( ( 0, @@ -79,8 +85,8 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): reward_trajs = np.empty((0, self.n_envs)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.vstack( - (obs_full_trajs, prev_obs_venv[None].squeeze(2)) - ) # remove cond_step dim + (obs_full_trajs, prev_obs_venv["state"][:, -1][None]) + ) # save current obs # Collect a set of trajectories from env for step in range(self.n_steps): @@ -89,8 +95,13 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = self.model( - cond=torch.from_numpy(prev_obs_venv).float().to(self.device), + cond=cond, deterministic=eval_mode, return_chain=True, ) @@ -101,21 +112,23 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): 
samples.chains.cpu().numpy() ) # n_env x denoising x horizon x act action_venv = output_venv[:, : self.act_steps] + samples_trajs = np.vstack((samples_trajs, output_venv[None])) # Apply multi-step action obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) - if self.save_full_observations: - obs_full_venv = np.vstack( - [info["full_obs"][None] for info in info_venv] - ) # n_envs x n_act_steps x obs_dim + if self.save_full_observations: # state-only + obs_full_venv = np.array( + [info["full_obs"]["state"] for info in info_venv] + ) # n_envs x act_steps x obs_dim obs_full_trajs = np.vstack( (obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) ) - obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None])) + obs_trajs["state"] = np.vstack( + (obs_trajs["state"], prev_obs_venv["state"][None]) + ) chains_trajs = np.vstack((chains_trajs, chains_venv[None])) - samples_trajs = np.vstack((samples_trajs, output_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None])) firsts_trajs[step + 1] = done_venv @@ -158,15 +171,25 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: with torch.no_grad(): - # Calculate value and logprobs - split into batches to prevent out of memory - obs_t = einops.rearrange( - torch.from_numpy(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", + obs_trajs["state"] = ( + torch.from_numpy(obs_trajs["state"]).float().to(self.device) ) - obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0) + + # Calculate value and logprobs - split into batches to prevent out of memory + num_split = math.ceil( + self.n_envs * self.n_steps / self.logprob_batch_size + ) + obs_ts = [{} for _ in range(num_split)] + obs_k = einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0) + for i, obs_t in enumerate(obs_ts_k): + obs_ts[i]["state"] = obs_t values_trajs = np.empty((0, self.n_envs)) for obs in obs_ts: values = self.model.critic(obs).cpu().numpy().flatten() @@ -193,7 +216,11 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): reward_trajs = reward_trajs_transpose.T # bootstrap value with GAE if not done - apply reward scaling with constant if specified - obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) + obs_venv_ts = { + "state": torch.from_numpy(obs_venv["state"]) + .float() + .to(self.device) + } with torch.no_grad(): next_value = ( self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() @@ -224,10 +251,12 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): returns_trajs = advantages_trajs + values_trajs # k for environment step - obs_k = einops.rearrange( - torch.tensor(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", - ) + obs_k = { + "state": einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + } samples_k = einops.rearrange( torch.tensor(samples_trajs).float().to(self.device), "s e h d -> (s e) h d", @@ -257,7 +286,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = obs_k[inds_b] + obs_b = {"state": obs_k["state"][inds_b]} samples_b = samples_k[inds_b] returns_b = returns_k[inds_b] values_b = values_k[inds_b] @@ -318,7 +347,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y ) - # Plot state trajectories + # Plot state trajectories (only in D3IL) if ( self.itr % self.render_freq == 0 and self.n_render > 0 @@ -354,6 +383,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): run_results[-1]["chains_trajs"] = chains_trajs run_results[-1]["reward_trajs"] = reward_trajs if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -374,7 +404,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -399,7 +429,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent): run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["explained_variance"] = explained_var run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git 
a/agent/finetune/train_ppo_gaussian_agent.py b/agent/finetune/train_ppo_gaussian_agent.py index 8803e10..cb37c0d 100644 --- a/agent/finetune/train_ppo_gaussian_agent.py +++ b/agent/finetune/train_ppo_gaussian_agent.py @@ -10,6 +10,7 @@ import numpy as np import torch import logging import wandb +import math log = logging.getLogger(__name__) from util.timer import Timer @@ -55,7 +56,9 @@ class TrainPPOGaussianAgent(TrainPPOAgent): ) # Holder - obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + obs_trajs = { + "state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + } samples_trajs = np.empty( ( 0, @@ -67,8 +70,8 @@ class TrainPPOGaussianAgent(TrainPPOAgent): reward_trajs = np.empty((0, self.n_envs)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.vstack( - (obs_full_trajs, prev_obs_venv[None].squeeze(2)) - ) # remove cond_step dim + (obs_full_trajs, prev_obs_venv["state"][:, -1][None]) + ) # save current obs # Collect a set of trajectories from env for step in range(self.n_steps): @@ -77,26 +80,33 @@ class TrainPPOGaussianAgent(TrainPPOAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = self.model( - cond=torch.from_numpy(prev_obs_venv).float().to(self.device), + cond=cond, deterministic=eval_mode, ) output_venv = samples.cpu().numpy() - action_venv = output_venv[:, : self.act_steps, : self.action_dim] - obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None])) - samples_trajs = np.vstack((samples_trajs, output_venv[None])) + action_venv = output_venv[:, : self.act_steps] # Apply multi-step action obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) - if self.save_full_observations: - obs_full_venv = np.vstack( - [info["full_obs"][None] for info in info_venv] - ) # n_envs x n_act_steps x obs_dim + if self.save_full_observations: # state-only + obs_full_venv = np.array( + 
[info["full_obs"]["state"] for info in info_venv] + ) # n_envs x act_steps x obs_dim obs_full_trajs = np.vstack( (obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) ) + obs_trajs["state"] = np.vstack( + (obs_trajs["state"], prev_obs_venv["state"][None]) + ) + samples_trajs = np.vstack((samples_trajs, output_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None])) firsts_trajs[step + 1] = done_venv @@ -144,15 +154,25 @@ class TrainPPOGaussianAgent(TrainPPOAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: with torch.no_grad(): - # Calculate value and logprobs - split into batches to prevent out of memory - obs_t = einops.rearrange( - torch.from_numpy(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", + obs_trajs["state"] = ( + torch.from_numpy(obs_trajs["state"]).float().to(self.device) ) - obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0) + + # Calculate value and logprobs - split into batches to prevent out of memory + num_split = math.ceil( + self.n_envs * self.n_steps / self.logprob_batch_size + ) + obs_ts = [{} for _ in range(num_split)] + obs_k = einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0) + for i, obs_t in enumerate(obs_ts_k): + obs_ts[i]["state"] = obs_t values_trajs = np.empty((0, self.n_envs)) for obs in obs_ts: values = self.model.critic(obs).cpu().numpy().flatten() @@ -184,7 +204,11 @@ class TrainPPOGaussianAgent(TrainPPOAgent): reward_trajs = reward_trajs_transpose.T # bootstrap value with GAE if not done - apply reward scaling with constant if specified - obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) + obs_venv_ts = { + "state": torch.from_numpy(obs_venv["state"]) + .float() + .to(self.device) + } with torch.no_grad(): next_value = ( self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() @@ -215,10 +239,12 @@ class TrainPPOGaussianAgent(TrainPPOAgent): returns_trajs = advantages_trajs + values_trajs # k for environment step - obs_k = einops.rearrange( - torch.tensor(obs_trajs).float().to(self.device), - "s e h d -> (s e) h d", - ) + obs_k = { + "state": einops.rearrange( + obs_trajs["state"], + "s e ... 
-> (s e) ...", + ) + } samples_k = einops.rearrange( torch.tensor(samples_trajs).float().to(self.device), "s e h d -> (s e) h d", @@ -250,7 +276,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = obs_k[inds_b] + obs_b = {"state": obs_k["state"][inds_b]} samples_b = samples_k[inds_b] returns_b = returns_k[inds_b] values_b = values_k[inds_b] @@ -313,7 +339,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent): np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y ) - # Plot state trajectories + # Plot state trajectories (only in D3IL) if ( self.itr % self.render_freq == 0 and self.n_render > 0 diff --git a/agent/finetune/train_ppo_gaussian_img_agent.py b/agent/finetune/train_ppo_gaussian_img_agent.py index 21251dd..37a2ed7 100644 --- a/agent/finetune/train_ppo_gaussian_img_agent.py +++ b/agent/finetune/train_ppo_gaussian_img_agent.py @@ -94,8 +94,8 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): key: torch.from_numpy(prev_obs_venv[key]) .float() .to(self.device) - for key in self.obs_dims.keys() - } # batch each type of obs and put into dict + for key in self.obs_dims + } samples = self.model( cond=cond, deterministic=eval_mode, @@ -107,7 +107,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) - for k in obs_trajs.keys(): + for k in obs_trajs: obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None])) samples_trajs = np.vstack((samples_trajs, output_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) @@ -152,7 +152,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: with torch.no_grad(): # apply image randomization @@ -180,7 +180,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): self.n_envs * 
self.n_steps / self.logprob_batch_size ) obs_ts = [{} for _ in range(num_split)] - for k in obs_trajs.keys(): + for k in obs_trajs: obs_k = einops.rearrange( obs_trajs[k], "s e ... -> (s e) ...", @@ -226,7 +226,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): # bootstrap value with GAE if not done - apply reward scaling with constant if specified obs_venv_ts = { key: torch.from_numpy(obs_venv[key]).float().to(self.device) - for key in self.obs_dims.keys() + for key in self.obs_dims } with torch.no_grad(): next_value = ( @@ -266,7 +266,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): obs_trajs[k], "s e ... -> (s e) ...", ) - for k in obs_trajs.keys() + for k in obs_trajs } samples_k = einops.rearrange( torch.tensor(samples_trajs).float().to(self.device), @@ -297,7 +297,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()} + obs_b = {k: obs_k[k][inds_b] for k in obs_k} samples_b = samples_k[inds_b] returns_b = returns_k[inds_b] values_b = values_k[inds_b] @@ -383,6 +383,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -403,7 +404,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -437,7 +438,7 @@ class 
TrainPPOImgGaussianAgent(TrainPPOGaussianAgent): run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["explained_variance"] = explained_var run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_qsm_diffusion_agent.py b/agent/finetune/train_qsm_diffusion_agent.py index 3ddacd2..fa80dd9 100644 --- a/agent/finetune/train_qsm_diffusion_agent.py +++ b/agent/finetune/train_qsm_diffusion_agent.py @@ -1,6 +1,8 @@ """ QSM (Q-Score Matching) for diffusion policy. +Do not support pixel input right now. + """ import os @@ -75,8 +77,8 @@ class TrainQSMDiffusionAgent(TrainAgent): # make a FIFO replay buffer for obs, action, and reward obs_buffer = deque(maxlen=self.buffer_size) - action_buffer = deque(maxlen=self.buffer_size) next_obs_buffer = deque(maxlen=self.buffer_size) + action_buffer = deque(maxlen=self.buffer_size) reward_buffer = deque(maxlen=self.buffer_size) done_buffer = deque(maxlen=self.buffer_size) first_buffer = deque(maxlen=self.buffer_size) @@ -84,6 +86,7 @@ class TrainQSMDiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -98,9 +101,10 @@ class TrainQSMDiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = 
self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -108,7 +112,6 @@ class TrainQSMDiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env @@ -118,11 +121,14 @@ class TrainQSMDiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode, ) .cpu() @@ -137,9 +143,9 @@ class TrainQSMDiffusionAgent(TrainAgent): reward_trajs = np.vstack((reward_trajs, reward_venv[None])) # add to buffer - obs_buffer.append(prev_obs_venv) + obs_buffer.append(prev_obs_venv["state"]) + next_obs_buffer.append(obs_venv["state"]) action_buffer.append(action_venv) - next_obs_buffer.append(obs_venv) reward_buffer.append(reward_venv * self.scale_reward_factor) done_buffer.append(done_venv) first_buffer.append(firsts_trajs[step]) @@ -184,7 +190,7 @@ class TrainQSMDiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: obs_trajs = np.array(deepcopy(obs_buffer)) @@ -233,7 +239,12 @@ class TrainQSMDiffusionAgent(TrainAgent): # update critic q function critic_loss = self.model.loss_critic( - obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma + {"state": obs_b}, + {"state": next_obs_b}, + actions_b, + reward_b, + done_b, + self.gamma, ) self.critic_optimizer.zero_grad() critic_loss.backward() @@ -246,7 +257,7 @@ class TrainQSMDiffusionAgent(TrainAgent): # Update policy with collected trajectories loss = self.model.loss_actor( - obs_b, + {"state": obs_b}, actions_b, self.q_grad_coeff, ) @@ -274,6 +285,7 @@ class TrainQSMDiffusionAgent(TrainAgent): } ) if self.itr % 
self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -294,7 +306,7 @@ class TrainQSMDiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -310,7 +322,7 @@ class TrainQSMDiffusionAgent(TrainAgent): run_results[-1]["loss"] = loss run_results[-1]["loss_critic"] = loss_critic run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/agent/finetune/train_rwr_diffusion_agent.py b/agent/finetune/train_rwr_diffusion_agent.py index 30d0daa..84ab2a5 100644 --- a/agent/finetune/train_rwr_diffusion_agent.py +++ b/agent/finetune/train_rwr_diffusion_agent.py @@ -1,6 +1,8 @@ """ Reward-weighted regression (RWR) for diffusion policy. +Do not support pixel input right now. 
+ """ import os @@ -54,6 +56,7 @@ class TrainRWRDiffusionAgent(TrainAgent): # Start training loop timer = Timer() run_results = [] + last_itr_eval = False done_venv = np.zeros((1, self.n_envs)) while self.itr < self.n_train_itr: @@ -68,9 +71,10 @@ class TrainRWRDiffusionAgent(TrainAgent): # Define train or eval - all envs restart eval_mode = self.itr % self.val_freq == 0 and not self.force_train self.model.eval() if eval_mode else self.model.train() - firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) + last_itr_eval = eval_mode - # Reset env at the beginning of an iteration + # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode + firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) if self.reset_at_iteration or eval_mode or last_itr_eval: prev_obs_venv = self.reset_env_all(options_venv=options_venv) firsts_trajs[0] = 1 @@ -78,11 +82,11 @@ class TrainRWRDiffusionAgent(TrainAgent): firsts_trajs[0] = ( done_venv # if done at the end of last iteration, then the envs are just reset ) - last_itr_eval = eval_mode - reward_trajs = np.empty((0, self.n_envs)) - # Holders - obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + # Holder + obs_trajs = { + "state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) + } samples_trajs = np.empty( ( 0, @@ -91,6 +95,7 @@ class TrainRWRDiffusionAgent(TrainAgent): self.action_dim, ) ) + reward_trajs = np.empty((0, self.n_envs)) # Collect a set of trajectories from env for step in range(self.n_steps): @@ -99,24 +104,29 @@ class TrainRWRDiffusionAgent(TrainAgent): # Select action with torch.no_grad(): + cond = { + "state": torch.from_numpy(prev_obs_venv["state"]) + .float() + .to(self.device) + } samples = ( self.model( - cond=torch.from_numpy(prev_obs_venv) - .float() - .to(self.device), + cond=cond, deterministic=eval_mode, ) .cpu() .numpy() ) # n_env x horizon x act action_venv = samples[:, : self.act_steps] - obs_trajs = np.vstack((obs_trajs, 
prev_obs_venv[None])) samples_trajs = np.vstack((samples_trajs, samples[None])) # Apply multi-step action obs_venv, reward_venv, done_venv, info_venv = self.venv.step( action_venv ) + obs_trajs["state"] = np.vstack( + (obs_trajs["state"], prev_obs_venv["state"][None]) + ) reward_trajs = np.vstack((reward_trajs, reward_venv[None])) firsts_trajs[step + 1] = done_venv prev_obs_venv = obs_venv @@ -133,7 +143,7 @@ class TrainRWRDiffusionAgent(TrainAgent): if len(episodes_start_end) > 0: # Compute transitions for completed trajectories obs_trajs_split = [ - obs_trajs[start : end + 1, env_ind] + {"state": obs_trajs["state"][start : end + 1, env_ind]} for env_ind, start, end in episodes_start_end ] samples_trajs_split = [ @@ -183,17 +193,20 @@ class TrainRWRDiffusionAgent(TrainAgent): success_rate = 0 log.info("[WARNING] No episode completed within the iteration!") - # Update + # Update models if not eval_mode: # Tensorize data and put them to device # k for environment step - obs_k = ( - torch.tensor(np.concatenate(obs_trajs_split)) + obs_k = { + "state": torch.tensor( + np.concatenate( + [obs_traj["state"] for obs_traj in obs_trajs_split] + ) + ) .float() .to(self.device) - ) - + } samples_k = ( torch.tensor(np.concatenate(samples_trajs_split)) .float() @@ -204,19 +217,15 @@ class TrainRWRDiffusionAgent(TrainAgent): returns_trajs_split = ( returns_trajs_split - np.mean(returns_trajs_split) ) / (returns_trajs_split.std() + 1e-3) - rewards_k = ( torch.tensor(returns_trajs_split) .float() .to(self.device) .reshape(-1) ) - rewards_k_scaled = torch.exp(self.beta * rewards_k) rewards_k_scaled.clamp_(max=self.max_reward_weight) - # rewards_k_scaled = rewards_k_scaled / rewards_k_scaled.mean() - # Update policy and critic total_steps = len(rewards_k_scaled) inds_k = np.arange(total_steps) @@ -229,7 +238,7 @@ class TrainRWRDiffusionAgent(TrainAgent): start = batch * self.batch_size end = start + self.batch_size inds_b = inds_k[start:end] # b for batch - obs_b = obs_k[inds_b] + 
obs_b = {"state": obs_k["state"][inds_b]} samples_b = samples_k[inds_b] rewards_b = rewards_k_scaled[inds_b] @@ -261,6 +270,7 @@ class TrainRWRDiffusionAgent(TrainAgent): } ) if self.itr % self.log_freq == 0: + time = timer() if eval_mode: log.info( f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" @@ -281,7 +291,7 @@ class TrainRWRDiffusionAgent(TrainAgent): run_results[-1]["eval_best_reward"] = avg_best_reward else: log.info( - f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" + f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}" ) if self.use_wandb: wandb.log( @@ -295,7 +305,7 @@ class TrainRWRDiffusionAgent(TrainAgent): ) run_results[-1]["loss"] = loss run_results[-1]["train_episode_reward"] = avg_episode_reward - run_results[-1]["time"] = timer() + run_results[-1]["time"] = time with open(self.result_path, "wb") as f: pickle.dump(run_results, f) self.itr += 1 diff --git a/cfg/d3il/finetune/avoid_m1/ft_ppo_diffusion_mlp.yaml b/cfg/d3il/finetune/avoid_m1/ft_ppo_diffusion_mlp.yaml index f3ea689..da3b099 100644 --- a/cfg/d3il/finetune/avoid_m1/ft_ppo_diffusion_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m1/ft_ppo_diffusion_mlp.yaml @@ -105,15 +105,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m1/ft_ppo_gaussian_mlp.yaml b/cfg/d3il/finetune/avoid_m1/ft_ppo_gaussian_mlp.yaml index 3bf6374..13526f1 100644 --- 
a/cfg/d3il/finetune/avoid_m1/ft_ppo_gaussian_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m1/ft_ppo_gaussian_mlp.yaml @@ -99,7 +99,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m1/ft_ppo_gmm_mlp.yaml b/cfg/d3il/finetune/avoid_m1/ft_ppo_gmm_mlp.yaml index 35a6a81..11e98dd 100644 --- a/cfg/d3il/finetune/avoid_m1/ft_ppo_gmm_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m1/ft_ppo_gmm_mlp.yaml @@ -100,7 +100,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m2/ft_ppo_diffusion_mlp.yaml b/cfg/d3il/finetune/avoid_m2/ft_ppo_diffusion_mlp.yaml index 9e69811..7dc7454 100644 --- a/cfg/d3il/finetune/avoid_m2/ft_ppo_diffusion_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m2/ft_ppo_diffusion_mlp.yaml @@ -105,15 +105,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m2/ft_ppo_gaussian_mlp.yaml b/cfg/d3il/finetune/avoid_m2/ft_ppo_gaussian_mlp.yaml index 539107f..86743b8 100644 --- a/cfg/d3il/finetune/avoid_m2/ft_ppo_gaussian_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m2/ft_ppo_gaussian_mlp.yaml @@ -99,7 
+99,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m2/ft_ppo_gmm_mlp.yaml b/cfg/d3il/finetune/avoid_m2/ft_ppo_gmm_mlp.yaml index 5a8b776..960fc60 100644 --- a/cfg/d3il/finetune/avoid_m2/ft_ppo_gmm_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m2/ft_ppo_gmm_mlp.yaml @@ -100,7 +100,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m3/ft_ppo_diffusion_mlp.yaml b/cfg/d3il/finetune/avoid_m3/ft_ppo_diffusion_mlp.yaml index d350d5b..c22d961 100644 --- a/cfg/d3il/finetune/avoid_m3/ft_ppo_diffusion_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m3/ft_ppo_diffusion_mlp.yaml @@ -105,15 +105,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m3/ft_ppo_gaussian_mlp.yaml b/cfg/d3il/finetune/avoid_m3/ft_ppo_gaussian_mlp.yaml index a626f54..4147674 100644 --- a/cfg/d3il/finetune/avoid_m3/ft_ppo_gaussian_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m3/ft_ppo_gaussian_mlp.yaml @@ -99,7 +99,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: 
${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/finetune/avoid_m3/ft_ppo_gmm_mlp.yaml b/cfg/d3il/finetune/avoid_m3/ft_ppo_gmm_mlp.yaml index 8f240ae..476fb33 100644 --- a/cfg/d3il/finetune/avoid_m3/ft_ppo_gmm_mlp.yaml +++ b/cfg/d3il/finetune/avoid_m3/ft_ppo_gmm_mlp.yaml @@ -100,7 +100,6 @@ model: _target_: model.common.critic.CriticObs mlp_dims: [256, 256, 256] residual_style: True - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/d3il/pretrain/avoid_m1/pre_diffusion_mlp.yaml b/cfg/d3il/pretrain/avoid_m1/pre_diffusion_mlp.yaml index 649e4d3..8ac6f58 100644 --- a/cfg/d3il/pretrain/avoid_m1/pre_diffusion_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m1/pre_diffusion_mlp.yaml @@ -54,9 +54,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/d3il/pretrain/avoid_m1/pre_gmm_mlp.yaml b/cfg/d3il/pretrain/avoid_m1/pre_gmm_mlp.yaml index 0204457..d980c50 100644 --- a/cfg/d3il/pretrain/avoid_m1/pre_gmm_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m1/pre_gmm_mlp.yaml @@ -47,7 +47,7 @@ model: residual_style: False fixed_std: 0.1 num_modes: ${num_modes} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} diff --git a/cfg/d3il/pretrain/avoid_m2/pre_diffusion_mlp.yaml b/cfg/d3il/pretrain/avoid_m2/pre_diffusion_mlp.yaml index 6565320..6ba8992 100644 --- a/cfg/d3il/pretrain/avoid_m2/pre_diffusion_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m2/pre_diffusion_mlp.yaml @@ -54,9 +54,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: 
${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/d3il/pretrain/avoid_m2/pre_gmm_mlp.yaml b/cfg/d3il/pretrain/avoid_m2/pre_gmm_mlp.yaml index 1b63a83..52eb8f9 100644 --- a/cfg/d3il/pretrain/avoid_m2/pre_gmm_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m2/pre_gmm_mlp.yaml @@ -47,7 +47,7 @@ model: residual_style: False fixed_std: 0.1 num_modes: ${num_modes} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} diff --git a/cfg/d3il/pretrain/avoid_m3/pre_diffusion_mlp.yaml b/cfg/d3il/pretrain/avoid_m3/pre_diffusion_mlp.yaml index 9cba4ca..7567116 100644 --- a/cfg/d3il/pretrain/avoid_m3/pre_diffusion_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m3/pre_diffusion_mlp.yaml @@ -54,9 +54,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/d3il/pretrain/avoid_m3/pre_gmm_mlp.yaml b/cfg/d3il/pretrain/avoid_m3/pre_gmm_mlp.yaml index b577ef6..a9b24bf 100644 --- a/cfg/d3il/pretrain/avoid_m3/pre_gmm_mlp.yaml +++ b/cfg/d3il/pretrain/avoid_m3/pre_gmm_mlp.yaml @@ -47,7 +47,7 @@ model: residual_style: False fixed_std: 0.1 num_modes: ${num_modes} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} diff --git a/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_mlp.yaml index 227efdc..64ce4b4 100644 --- a/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_mlp.yaml +++ b/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: 
model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_unet.yaml index 82652a8..deb8adc 100644 --- a/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_unet.yaml +++ b/cfg/furniture/finetune/lamp_low/ft_ppo_diffusion_unet.yaml @@ -109,15 +109,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/lamp_low/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/lamp_low/ft_ppo_gaussian_mlp.yaml index 5efed86..b3d5002 100644 --- a/cfg/furniture/finetune/lamp_low/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/lamp_low/ft_ppo_gaussian_mlp.yaml @@ -95,15 +95,14 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} 
device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_mlp.yaml index 1ddd372..ad81dcb 100644 --- a/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_mlp.yaml +++ b/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_unet.yaml index a600120..9131c2f 100644 --- a/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_unet.yaml +++ b/cfg/furniture/finetune/lamp_med/ft_ppo_diffusion_unet.yaml @@ -108,15 +108,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/lamp_med/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/lamp_med/ft_ppo_gaussian_mlp.yaml index 7c538a9..49c3a50 100644 --- a/cfg/furniture/finetune/lamp_med/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/lamp_med/ft_ppo_gaussian_mlp.yaml @@ -95,15 +95,14 @@ model: learn_fixed_std: True 
std_min: 0.01 std_max: 0.2 - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_mlp.yaml index 2e5aecc..7ef5843 100644 --- a/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_mlp.yaml +++ b/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_unet.yaml index 0caead2..8f6088a 100644 --- a/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_unet.yaml +++ b/cfg/furniture/finetune/one_leg_low/ft_ppo_diffusion_unet.yaml @@ -109,15 +109,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - 
cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_low/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/one_leg_low/ft_ppo_gaussian_mlp.yaml index 31ea4ed..353322d 100644 --- a/cfg/furniture/finetune/one_leg_low/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/one_leg_low/ft_ppo_gaussian_mlp.yaml @@ -95,15 +95,14 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_mlp.yaml index c74a22d..7ad6681 100644 --- a/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_mlp.yaml +++ b/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_unet.yaml index 8a0c042..77fb4d4 100644 --- a/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_unet.yaml +++ 
b/cfg/furniture/finetune/one_leg_med/ft_ppo_diffusion_unet.yaml @@ -109,15 +109,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/one_leg_med/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/one_leg_med/ft_ppo_gaussian_mlp.yaml index abd5c30..98ff894 100644 --- a/cfg/furniture/finetune/one_leg_med/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/one_leg_med/ft_ppo_gaussian_mlp.yaml @@ -95,15 +95,14 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_mlp.yaml index 5294173..d5b399f 100644 --- a/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_mlp.yaml +++ b/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - 
transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_unet.yaml index 16c27bf..d873337 100644 --- a/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_unet.yaml +++ b/cfg/furniture/finetune/round_table_low/ft_ppo_diffusion_unet.yaml @@ -109,15 +109,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_low/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/round_table_low/ft_ppo_gaussian_mlp.yaml index 0d4a2d9..15731da 100644 --- a/cfg/furniture/finetune/round_table_low/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/round_table_low/ft_ppo_gaussian_mlp.yaml @@ -100,10 +100,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_mlp.yaml b/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_mlp.yaml index aee7ec5..332649a 100644 --- a/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_mlp.yaml +++ 
b/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_mlp.yaml @@ -107,15 +107,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_unet.yaml b/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_unet.yaml index 72cf8a8..b4b6fac 100644 --- a/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_unet.yaml +++ b/cfg/furniture/finetune/round_table_med/ft_ppo_diffusion_unet.yaml @@ -108,15 +108,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/finetune/round_table_med/ft_ppo_gaussian_mlp.yaml b/cfg/furniture/finetune/round_table_med/ft_ppo_gaussian_mlp.yaml index 7aba0ae..ebad59a 100644 --- a/cfg/furniture/finetune/round_table_med/ft_ppo_gaussian_mlp.yaml +++ b/cfg/furniture/finetune/round_table_med/ft_ppo_gaussian_mlp.yaml @@ -95,15 +95,14 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - 
obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [512, 512, 512] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/furniture/pretrain/lamp_low/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/lamp_low/pre_diffusion_mlp.yaml index e22a8d2..775916f 100644 --- a/cfg/furniture/pretrain/lamp_low/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/lamp_low/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/lamp_low/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/lamp_low/pre_diffusion_unet.yaml index 933a251..fc22513 100644 --- a/cfg/furniture/pretrain/lamp_low/pre_diffusion_unet.yaml +++ b/cfg/furniture/pretrain/lamp_low/pre_diffusion_unet.yaml @@ -58,9 +58,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/lamp_med/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/lamp_med/pre_diffusion_mlp.yaml index 8d21455..827739c 100644 --- a/cfg/furniture/pretrain/lamp_med/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/lamp_med/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/lamp_med/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/lamp_med/pre_diffusion_unet.yaml index 7f87d43..5d7bc9f 100644 --- a/cfg/furniture/pretrain/lamp_med/pre_diffusion_unet.yaml +++ 
b/cfg/furniture/pretrain/lamp_med/pre_diffusion_unet.yaml @@ -57,9 +57,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/one_leg_low/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/one_leg_low/pre_diffusion_mlp.yaml index 80dce0c..c996393 100644 --- a/cfg/furniture/pretrain/one_leg_low/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/one_leg_low/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/one_leg_low/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/one_leg_low/pre_diffusion_unet.yaml index cccb378..253b89b 100644 --- a/cfg/furniture/pretrain/one_leg_low/pre_diffusion_unet.yaml +++ b/cfg/furniture/pretrain/one_leg_low/pre_diffusion_unet.yaml @@ -58,9 +58,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/one_leg_med/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/one_leg_med/pre_diffusion_mlp.yaml index d8050cc..57bad4b 100644 --- a/cfg/furniture/pretrain/one_leg_med/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/one_leg_med/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/one_leg_med/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/one_leg_med/pre_diffusion_unet.yaml index 22ddcfb..d588590 100644 --- 
a/cfg/furniture/pretrain/one_leg_med/pre_diffusion_unet.yaml +++ b/cfg/furniture/pretrain/one_leg_med/pre_diffusion_unet.yaml @@ -57,9 +57,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/round_table_low/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/round_table_low/pre_diffusion_mlp.yaml index 6e32866..9930155 100644 --- a/cfg/furniture/pretrain/round_table_low/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/round_table_low/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/round_table_low/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/round_table_low/pre_diffusion_unet.yaml index bd3c29b..a9a9de1 100644 --- a/cfg/furniture/pretrain/round_table_low/pre_diffusion_unet.yaml +++ b/cfg/furniture/pretrain/round_table_low/pre_diffusion_unet.yaml @@ -57,9 +57,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/furniture/pretrain/round_table_med/pre_diffusion_mlp.yaml b/cfg/furniture/pretrain/round_table_med/pre_diffusion_mlp.yaml index e7e3da9..a8f0d83 100644 --- a/cfg/furniture/pretrain/round_table_med/pre_diffusion_mlp.yaml +++ b/cfg/furniture/pretrain/round_table_med/pre_diffusion_mlp.yaml @@ -56,9 +56,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git 
a/cfg/furniture/pretrain/round_table_med/pre_diffusion_unet.yaml b/cfg/furniture/pretrain/round_table_med/pre_diffusion_unet.yaml index cd85ebd..e0c351a 100644 --- a/cfg/furniture/pretrain/round_table_med/pre_diffusion_unet.yaml +++ b/cfg/furniture/pretrain/round_table_med/pre_diffusion_unet.yaml @@ -57,9 +57,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_awr_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_awr_diffusion_mlp.yaml index 4e733df..4a99ae1 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_awr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_awr_diffusion_mlp.yaml @@ -83,20 +83,19 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_dipo_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_dipo_diffusion_mlp.yaml index b9641ef..f132db1 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_dipo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_dipo_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * 
${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_dql_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_dql_diffusion_mlp.yaml index a26cbf7..48a907a 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_dql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_dql_diffusion_mlp.yaml @@ -81,7 +81,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -90,13 +90,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_idql_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_idql_diffusion_mlp.yaml index d45cfaf..2c82fdd 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_idql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_idql_diffusion_mlp.yaml @@ -84,7 +84,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - 
cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -93,19 +93,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_ppo_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_ppo_diffusion_mlp.yaml index cad9e82..1363b2d 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_ppo_diffusion_mlp.yaml @@ -91,20 +91,18 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_ppo_exact_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_ppo_exact_diffusion_mlp.yaml index 49b45ab..2a71ecf 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_ppo_exact_diffusion_mlp.yaml 
+++ b/cfg/gym/finetune/halfcheetah-v2/ft_ppo_exact_diffusion_mlp.yaml @@ -99,20 +99,18 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_qsm_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_qsm_diffusion_mlp.yaml index a36aee6..bdf4a1b 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ft_qsm_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_qsm_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ft_rwr_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ft_rwr_diffusion_mlp.yaml index 9c36ae1..ccb3a63 100644 --- 
a/cfg/gym/finetune/halfcheetah-v2/ft_rwr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ft_rwr_diffusion_mlp.yaml @@ -74,7 +74,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -82,6 +82,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ppo_diffusion_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ppo_diffusion_mlp.yaml index 0c496ef..05a60c8 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ppo_diffusion_mlp.yaml @@ -87,22 +87,20 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/halfcheetah-v2/ppo_gaussian_mlp.yaml b/cfg/gym/finetune/halfcheetah-v2/ppo_gaussian_mlp.yaml index 2d3e9c9..4eb1cbb 100644 --- a/cfg/gym/finetune/halfcheetah-v2/ppo_gaussian_mlp.yaml +++ b/cfg/gym/finetune/halfcheetah-v2/ppo_gaussian_mlp.yaml @@ -80,15 +80,14 @@ model: mlp_dims: [512, 512, 512] activation_type: 
ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_awr_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_awr_diffusion_mlp.yaml index b0239b6..6fb1f04 100644 --- a/cfg/gym/finetune/hopper-v2/ft_awr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_awr_diffusion_mlp.yaml @@ -83,20 +83,19 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_dipo_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_dipo_diffusion_mlp.yaml index 259f4df..0bee961 100644 --- a/cfg/gym/finetune/hopper-v2/ft_dipo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_dipo_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: 
model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_dql_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_dql_diffusion_mlp.yaml index 4d8d9d5..4ca0a9d 100644 --- a/cfg/gym/finetune/hopper-v2/ft_dql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_dql_diffusion_mlp.yaml @@ -81,7 +81,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -90,13 +90,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_idql_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_idql_diffusion_mlp.yaml index 3d29412..eb345d8 100644 --- a/cfg/gym/finetune/hopper-v2/ft_idql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_idql_diffusion_mlp.yaml @@ -84,7 +84,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -93,19 +93,18 @@ 
model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_ppo_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_ppo_diffusion_mlp.yaml index 76364f8..3467f91 100644 --- a/cfg/gym/finetune/hopper-v2/ft_ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_ppo_diffusion_mlp.yaml @@ -91,20 +91,18 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_ppo_exact_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_ppo_exact_diffusion_mlp.yaml index 58cc997..3f44524 100644 --- a/cfg/gym/finetune/hopper-v2/ft_ppo_exact_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_ppo_exact_diffusion_mlp.yaml @@ -98,20 +98,18 @@ model: time_dim: 16 mlp_dims: [512, 512, 512] residual_style: True - cond_dim: ${obs_dim} + 
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_qsm_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_qsm_diffusion_mlp.yaml index a627f94..3656300 100644 --- a/cfg/gym/finetune/hopper-v2/ft_qsm_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_qsm_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ft_rwr_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ft_rwr_diffusion_mlp.yaml index 7cdf82e..2cdc50d 100644 --- a/cfg/gym/finetune/hopper-v2/ft_rwr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ft_rwr_diffusion_mlp.yaml @@ -74,7 +74,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: 
${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -82,6 +82,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ppo_diffusion_mlp.yaml b/cfg/gym/finetune/hopper-v2/ppo_diffusion_mlp.yaml index 8ce5397..89c75c1 100644 --- a/cfg/gym/finetune/hopper-v2/ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ppo_diffusion_mlp.yaml @@ -87,22 +87,20 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/hopper-v2/ppo_gaussian_mlp.yaml b/cfg/gym/finetune/hopper-v2/ppo_gaussian_mlp.yaml index 033cdbf..17865c5 100644 --- a/cfg/gym/finetune/hopper-v2/ppo_gaussian_mlp.yaml +++ b/cfg/gym/finetune/hopper-v2/ppo_gaussian_mlp.yaml @@ -80,15 +80,14 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] 
activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_awr_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_awr_diffusion_mlp.yaml index 319d126..a529e3d 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_awr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_awr_diffusion_mlp.yaml @@ -83,20 +83,19 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_dipo_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_dipo_diffusion_mlp.yaml index dcdf747..7ef6db2 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_dipo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_dipo_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: 
${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_dql_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_dql_diffusion_mlp.yaml index 9f147f5..e6b8d9e 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_dql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_dql_diffusion_mlp.yaml @@ -81,7 +81,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -90,13 +90,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_idql_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_idql_diffusion_mlp.yaml index d419dfe..3c46b89 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_idql_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_idql_diffusion_mlp.yaml @@ -84,7 +84,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -93,19 +93,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: 
model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_ppo_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_ppo_diffusion_mlp.yaml index 2c25c54..00dd321 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_ppo_diffusion_mlp.yaml @@ -91,20 +91,18 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_qsm_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_qsm_diffusion_mlp.yaml index 3ac51e0..e91d9ff 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_qsm_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_qsm_diffusion_mlp.yaml @@ -82,7 +82,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -91,13 +91,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: 
${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ft_rwr_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ft_rwr_diffusion_mlp.yaml index 2d097d1..6be6eb6 100644 --- a/cfg/gym/finetune/walker2d-v2/ft_rwr_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ft_rwr_diffusion_mlp.yaml @@ -74,7 +74,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -82,6 +82,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ppo_diffusion_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ppo_diffusion_mlp.yaml index 0ee1f77..dafbea6 100644 --- a/cfg/gym/finetune/walker2d-v2/ppo_diffusion_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ppo_diffusion_mlp.yaml @@ -87,22 +87,20 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: 
${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/finetune/walker2d-v2/ppo_gaussian_mlp.yaml b/cfg/gym/finetune/walker2d-v2/ppo_gaussian_mlp.yaml index 1764603..a644eda 100644 --- a/cfg/gym/finetune/walker2d-v2/ppo_gaussian_mlp.yaml +++ b/cfg/gym/finetune/walker2d-v2/ppo_gaussian_mlp.yaml @@ -80,15 +80,14 @@ model: mlp_dims: [512, 512, 512] activation_type: ReLU residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/gym/pretrain/halfcheetah-medium-v2/pre_diffusion_mlp.yaml b/cfg/gym/pretrain/halfcheetah-medium-v2/pre_diffusion_mlp.yaml index 6c83483..6e20b5f 100644 --- a/cfg/gym/pretrain/halfcheetah-medium-v2/pre_diffusion_mlp.yaml +++ b/cfg/gym/pretrain/halfcheetah-medium-v2/pre_diffusion_mlp.yaml @@ -45,7 +45,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -55,9 +55,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/gym/pretrain/hopper-medium-v2/pre_diffusion_mlp.yaml b/cfg/gym/pretrain/hopper-medium-v2/pre_diffusion_mlp.yaml index 236f13e..c1428a6 100644 --- a/cfg/gym/pretrain/hopper-medium-v2/pre_diffusion_mlp.yaml +++ 
b/cfg/gym/pretrain/hopper-medium-v2/pre_diffusion_mlp.yaml @@ -45,7 +45,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -55,9 +55,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/gym/pretrain/walker2d-medium-v2/pre_diffusion_mlp.yaml b/cfg/gym/pretrain/walker2d-medium-v2/pre_diffusion_mlp.yaml index ab92bc0..893caa5 100644 --- a/cfg/gym/pretrain/walker2d-medium-v2/pre_diffusion_mlp.yaml +++ b/cfg/gym/pretrain/walker2d-medium-v2/pre_diffusion_mlp.yaml @@ -45,7 +45,7 @@ model: _target_: model.diffusion.mlp_diffusion.DiffusionMLP horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} time_dim: 16 mlp_dims: [512, 512, 512] activation_type: ReLU @@ -55,9 +55,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/finetune/can/ft_awr_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_awr_diffusion_mlp.yaml index c0916f6..4b095d6 100644 --- a/cfg/robomimic/finetune/can/ft_awr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_awr_diffusion_mlp.yaml @@ -94,13 +94,12 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} 
denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_dipo_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_dipo_diffusion_mlp.yaml index b5c7867..a6876c4 100644 --- a/cfg/robomimic/finetune/can/ft_dipo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_dipo_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_dql_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_dql_diffusion_mlp.yaml index 5b94c6f..9086fe9 100644 --- a/cfg/robomimic/finetune/can/ft_dql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_dql_diffusion_mlp.yaml @@ -94,13 +94,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_idql_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_idql_diffusion_mlp.yaml index 4f03622..d0ee1c5 100644 --- a/cfg/robomimic/finetune/can/ft_idql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_idql_diffusion_mlp.yaml @@ -97,19 +97,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * 
${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp.yaml index 988c2ef..e91c255 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp.yaml @@ -100,15 +100,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml index 5d9c9a8..d47e664 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_diffusion_mlp_img.yaml @@ -20,6 +20,7 @@ transition_dim: ${action_dim} denoising_steps: 100 ft_denoising_steps: 5 cond_steps: 1 +img_cond_steps: 1 horizon_steps: 4 act_steps: 4 use_ddim: True @@ -121,6 +122,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -133,6 +135,7 @@ model: 
time_dim: 32 mlp_dims: [512, 512, 512] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -143,6 +146,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -150,15 +154,14 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_diffusion_unet.yaml b/cfg/robomimic/finetune/can/ft_ppo_diffusion_unet.yaml index addda3d..6c94bf6 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_diffusion_unet.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_diffusion_unet.yaml @@ -103,15 +103,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_exact_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_ppo_exact_diffusion_mlp.yaml index 2995c60..617bdab 100644 --- 
a/cfg/robomimic/finetune/can/ft_ppo_exact_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_exact_diffusion_mlp.yaml @@ -108,15 +108,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp.yaml b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp.yaml index f47e7a8..cc267e8 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp.yaml @@ -94,10 +94,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml index 0d94f40..9613b10 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gaussian_mlp_img.yaml @@ -18,6 +18,7 @@ obs_dim: 9 action_dim: 7 transition_dim: ${action_dim} cond_steps: 1 +img_cond_steps: 1 horizon_steps: 4 act_steps: 4 @@ -100,6 +101,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -115,6 +117,7 @@ model: learn_fixed_std: True std_min: 
0.01 std_max: 0.2 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -125,6 +128,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -132,10 +136,10 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_gaussian_transformer.yaml b/cfg/robomimic/finetune/can/ft_ppo_gaussian_transformer.yaml index 948e20f..c4930af 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gaussian_transformer.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gaussian_transformer.yaml @@ -26,7 +26,7 @@ env: name: ${env_name} best_reward_threshold_for_success: 1 max_episode_steps: 300 - save_video: false + save_video: False wrappers: robomimic_lowdim: normalization_path: ${normalization_path} @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_gmm_mlp.yaml b/cfg/robomimic/finetune/can/ft_ppo_gmm_mlp.yaml index fffe5bb..bdfe130 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gmm_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gmm_mlp.yaml @@ -27,7 +27,7 @@ env: name: ${env_name} best_reward_threshold_for_success: 1 max_episode_steps: 300 - 
save_video: false + save_video: False wrappers: robomimic_lowdim: normalization_path: ${normalization_path} @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_ppo_gmm_transformer.yaml b/cfg/robomimic/finetune/can/ft_ppo_gmm_transformer.yaml index 5bc098b..ea2a4ce 100644 --- a/cfg/robomimic/finetune/can/ft_ppo_gmm_transformer.yaml +++ b/cfg/robomimic/finetune/can/ft_ppo_gmm_transformer.yaml @@ -27,7 +27,7 @@ env: name: ${env_name} best_reward_threshold_for_success: 1 max_episode_steps: 300 - save_video: false + save_video: False wrappers: robomimic_lowdim: normalization_path: ${normalization_path} @@ -96,10 +96,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_qsm_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_qsm_diffusion_mlp.yaml index a590263..9d7396e 100644 --- a/cfg/robomimic/finetune/can/ft_qsm_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_qsm_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: 
${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/can/ft_rwr_diffusion_mlp.yaml b/cfg/robomimic/finetune/can/ft_rwr_diffusion_mlp.yaml index 3ff9c2d..9fe8610 100644 --- a/cfg/robomimic/finetune/can/ft_rwr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/can/ft_rwr_diffusion_mlp.yaml @@ -86,6 +86,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_awr_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_awr_diffusion_mlp.yaml index 8a797e3..18d016f 100644 --- a/cfg/robomimic/finetune/lift/ft_awr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_awr_diffusion_mlp.yaml @@ -94,13 +94,12 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_dipo_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_dipo_diffusion_mlp.yaml index f3feb20..613e4ee 100644 --- a/cfg/robomimic/finetune/lift/ft_dipo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_dipo_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git 
a/cfg/robomimic/finetune/lift/ft_dql_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_dql_diffusion_mlp.yaml index a9b0282..9705f53 100644 --- a/cfg/robomimic/finetune/lift/ft_dql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_dql_diffusion_mlp.yaml @@ -94,13 +94,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_idql_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_idql_diffusion_mlp.yaml index 563f9eb..ef5c502 100644 --- a/cfg/robomimic/finetune/lift/ft_idql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_idql_diffusion_mlp.yaml @@ -97,19 +97,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp.yaml index ffaf8b5..35a0318 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp.yaml @@ -100,15 +100,13 @@ model: transition_dim: ${transition_dim} critic: 
_target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml index 82a2449..fe36cb5 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_mlp_img.yaml @@ -20,6 +20,7 @@ transition_dim: ${action_dim} denoising_steps: 100 ft_denoising_steps: 5 cond_steps: 1 +img_cond_steps: 1 horizon_steps: 4 act_steps: 4 use_ddim: True @@ -121,6 +122,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -133,6 +135,7 @@ model: time_dim: 32 mlp_dims: [512, 512, 512] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -143,6 +146,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -150,15 +154,14 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} 
horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_unet.yaml b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_unet.yaml index 3bc6975..451e9ec 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_diffusion_unet.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_diffusion_unet.yaml @@ -103,15 +103,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp.yaml index 9f6be1b..9134bcd 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp.yaml @@ -94,10 +94,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml index 7e055e8..3432e4b 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_mlp_img.yaml @@ -18,6 +18,7 @@ obs_dim: 9 action_dim: 7 transition_dim: ${action_dim} cond_steps: 1 +img_cond_steps: 1 
horizon_steps: 4 act_steps: 4 @@ -100,6 +101,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -115,6 +117,7 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -125,6 +128,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -132,10 +136,10 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_transformer.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_transformer.yaml index c543cb7..3e32e7e 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gaussian_transformer.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_gaussian_transformer.yaml @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gmm_mlp.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gmm_mlp.yaml index 4bd2406..84e0c78 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gmm_mlp.yaml +++ 
b/cfg/robomimic/finetune/lift/ft_ppo_gmm_mlp.yaml @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_ppo_gmm_transformer.yaml b/cfg/robomimic/finetune/lift/ft_ppo_gmm_transformer.yaml index d4469f3..bee15ff 100644 --- a/cfg/robomimic/finetune/lift/ft_ppo_gmm_transformer.yaml +++ b/cfg/robomimic/finetune/lift/ft_ppo_gmm_transformer.yaml @@ -96,10 +96,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_qsm_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_qsm_diffusion_mlp.yaml index ae5b684..6a7c1c7 100644 --- a/cfg/robomimic/finetune/lift/ft_qsm_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/lift/ft_qsm_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/lift/ft_rwr_diffusion_mlp.yaml b/cfg/robomimic/finetune/lift/ft_rwr_diffusion_mlp.yaml index e870256..ff0f7fa 100644 --- a/cfg/robomimic/finetune/lift/ft_rwr_diffusion_mlp.yaml +++ 
b/cfg/robomimic/finetune/lift/ft_rwr_diffusion_mlp.yaml @@ -86,6 +86,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_awr_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_awr_diffusion_mlp.yaml index afcf62b..7a44806 100644 --- a/cfg/robomimic/finetune/square/ft_awr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_awr_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_dipo_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_dipo_diffusion_mlp.yaml index 5f2baa3..6ab3d87 100644 --- a/cfg/robomimic/finetune/square/ft_dipo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_dipo_diffusion_mlp.yaml @@ -96,13 +96,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_dql_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_dql_diffusion_mlp.yaml index d834dc9..f545be8 100644 --- a/cfg/robomimic/finetune/square/ft_dql_diffusion_mlp.yaml +++ 
b/cfg/robomimic/finetune/square/ft_dql_diffusion_mlp.yaml @@ -95,13 +95,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_idql_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_idql_diffusion_mlp.yaml index 9245579..48adb2a 100644 --- a/cfg/robomimic/finetune/square/ft_idql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_idql_diffusion_mlp.yaml @@ -98,19 +98,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp.yaml index ee41faf..f8b8bbc 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp.yaml @@ -101,15 +101,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True 
ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml index ffc6320..1e8ebfa 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_diffusion_mlp_img.yaml @@ -20,6 +20,7 @@ transition_dim: ${action_dim} denoising_steps: 100 ft_denoising_steps: 5 cond_steps: 1 +img_cond_steps: 1 horizon_steps: 4 act_steps: 4 use_ddim: True @@ -121,6 +122,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -133,6 +135,7 @@ model: time_dim: 32 mlp_dims: [768, 768, 768] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -143,6 +146,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -150,15 +154,14 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end 
of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_diffusion_unet.yaml b/cfg/robomimic/finetune/square/ft_ppo_diffusion_unet.yaml index 5020ef2..10aa67a 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_diffusion_unet.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_diffusion_unet.yaml @@ -103,15 +103,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp.yaml b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp.yaml index 9f7ec25..d705195 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp.yaml @@ -94,10 +94,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml index 5a96b16..7bbe198 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gaussian_mlp_img.yaml @@ -18,6 +18,7 @@ obs_dim: 9 action_dim: 7 transition_dim: ${action_dim} cond_steps: 1 +img_cond_steps: 1 horizon_steps: 4 act_steps: 4 @@ -100,6 +101,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: 
${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -115,6 +117,7 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -125,6 +128,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -132,10 +136,10 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_gaussian_transformer.yaml b/cfg/robomimic/finetune/square/ft_ppo_gaussian_transformer.yaml index 9bf6f10..ac4cb99 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gaussian_transformer.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gaussian_transformer.yaml @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_gmm_mlp.yaml b/cfg/robomimic/finetune/square/ft_ppo_gmm_mlp.yaml index 62ca68a..2f85676 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gmm_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gmm_mlp.yaml @@ -95,10 +95,9 @@ model: transition_dim: ${transition_dim} critic: _target_: 
model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_ppo_gmm_transformer.yaml b/cfg/robomimic/finetune/square/ft_ppo_gmm_transformer.yaml index e185604..d0c82db 100644 --- a/cfg/robomimic/finetune/square/ft_ppo_gmm_transformer.yaml +++ b/cfg/robomimic/finetune/square/ft_ppo_gmm_transformer.yaml @@ -96,10 +96,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_qsm_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_qsm_diffusion_mlp.yaml index b1eb458..90e72d7 100644 --- a/cfg/robomimic/finetune/square/ft_qsm_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_qsm_diffusion_mlp.yaml @@ -96,13 +96,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/square/ft_rwr_diffusion_mlp.yaml b/cfg/robomimic/finetune/square/ft_rwr_diffusion_mlp.yaml index 1c731ef..dbc8924 100644 --- a/cfg/robomimic/finetune/square/ft_rwr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/square/ft_rwr_diffusion_mlp.yaml @@ -87,6 +87,5 @@ model: horizon_steps: ${horizon_steps} 
obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_awr_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_awr_diffusion_mlp.yaml index d3995ef..025fb48 100644 --- a/cfg/robomimic/finetune/transport/ft_awr_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_awr_diffusion_mlp.yaml @@ -97,13 +97,12 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_dipo_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_dipo_diffusion_mlp.yaml index a7690fd..91d8e96 100644 --- a/cfg/robomimic/finetune/transport/ft_dipo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_dipo_diffusion_mlp.yaml @@ -98,13 +98,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_dql_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_dql_diffusion_mlp.yaml index 40a3588..72041f4 100644 --- a/cfg/robomimic/finetune/transport/ft_dql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_dql_diffusion_mlp.yaml @@ -97,13 +97,12 @@ model: _target_: 
model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_idql_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_idql_diffusion_mlp.yaml index 36eabe5..9b670b9 100644 --- a/cfg/robomimic/finetune/transport/ft_idql_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_idql_diffusion_mlp.yaml @@ -100,19 +100,18 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True critic_v: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp.yaml index f962739..e887fce 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp.yaml @@ -103,15 +103,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} 
horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml index c5b5d64..87f336a 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_mlp_img.yaml @@ -20,6 +20,7 @@ transition_dim: ${action_dim} denoising_steps: 100 ft_denoising_steps: 5 cond_steps: 1 +img_cond_steps: 1 horizon_steps: 8 act_steps: 8 use_ddim: True @@ -125,6 +126,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -138,6 +140,7 @@ model: time_dim: 32 mlp_dims: [768, 768, 768] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -149,6 +152,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -156,15 +160,14 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git 
a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_unet.yaml b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_unet.yaml index b9c7f15..d17e184 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_diffusion_unet.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_diffusion_unet.yaml @@ -106,15 +106,13 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True ft_denoising_steps: ${ft_denoising_steps} - transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - cond_steps: ${cond_steps} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp.yaml index 85b7db3..9a5650a 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp.yaml @@ -97,10 +97,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml index 24a5696..9503c1d 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_mlp_img.yaml @@ -18,6 +18,7 @@ obs_dim: 18 action_dim: 14 transition_dim: ${action_dim} cond_steps: 1 +img_cond_steps: 1 horizon_steps: 8 act_steps: 8 @@ -104,6 +105,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: 
${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -120,6 +122,7 @@ model: learn_fixed_std: True std_min: 0.01 std_max: 0.2 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -131,6 +134,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -138,10 +142,10 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 - obs_dim: ${obs_dim} + img_cond_steps: ${img_cond_steps} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_transformer.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_transformer.yaml index bbab06d..9d4eaa0 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gaussian_transformer.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gaussian_transformer.yaml @@ -98,10 +98,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gmm_mlp.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gmm_mlp.yaml index c961d88..07e3cf5 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gmm_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gmm_mlp.yaml @@ -98,10 +98,9 @@ model: 
transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_ppo_gmm_transformer.yaml b/cfg/robomimic/finetune/transport/ft_ppo_gmm_transformer.yaml index 32f0d61..b35c5dc 100644 --- a/cfg/robomimic/finetune/transport/ft_ppo_gmm_transformer.yaml +++ b/cfg/robomimic/finetune/transport/ft_ppo_gmm_transformer.yaml @@ -99,10 +99,9 @@ model: transition_dim: ${transition_dim} critic: _target_: model.common.critic.CriticObs - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} - cond_steps: ${cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_qsm_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_qsm_diffusion_mlp.yaml index 235a747..60fc95c 100644 --- a/cfg/robomimic/finetune/transport/ft_qsm_diffusion_mlp.yaml +++ b/cfg/robomimic/finetune/transport/ft_qsm_diffusion_mlp.yaml @@ -98,13 +98,12 @@ model: _target_: model.common.critic.CriticObsAct action_dim: ${action_dim} action_steps: ${act_steps} - obs_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} mlp_dims: [256, 256, 256] activation_type: Mish residual_style: True horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/finetune/transport/ft_rwr_diffusion_mlp.yaml b/cfg/robomimic/finetune/transport/ft_rwr_diffusion_mlp.yaml index a70659d..016ccee 100644 --- a/cfg/robomimic/finetune/transport/ft_rwr_diffusion_mlp.yaml +++ 
b/cfg/robomimic/finetune/transport/ft_rwr_diffusion_mlp.yaml @@ -89,6 +89,5 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/can/pre_diffusion_mlp.yaml b/cfg/robomimic/pretrain/can/pre_diffusion_mlp.yaml index 8c9c588..80a5a6d 100644 --- a/cfg/robomimic/pretrain/can/pre_diffusion_mlp.yaml +++ b/cfg/robomimic/pretrain/can/pre_diffusion_mlp.yaml @@ -46,15 +46,13 @@ model: time_dim: 16 mlp_dims: [512, 512, 512] residual_style: True - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/can/pre_diffusion_mlp_img.yaml b/cfg/robomimic/pretrain/can/pre_diffusion_mlp_img.yaml index 5761803..b91268e 100644 --- a/cfg/robomimic/pretrain/can/pre_diffusion_mlp_img.yaml +++ b/cfg/robomimic/pretrain/can/pre_diffusion_mlp_img.yaml @@ -18,6 +18,7 @@ transition_dim: ${action_dim} denoising_steps: 100 horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -27,7 +28,7 @@ wandb: shape_meta: obs: rgb: - shape: [3, 96, 96] + shape: [3, 96, 96] # not counting img_cond_steps state: shape: [9] action: @@ -55,6 +56,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -62,6 +64,7 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 + img_cond_steps: ${img_cond_steps} augment: True spatial_emb: 128 time_dim: 32 @@ -73,9 +76,7 @@ model: horizon_steps: 
${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: @@ -88,4 +89,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/can/pre_diffusion_unet.yaml b/cfg/robomimic/pretrain/can/pre_diffusion_unet.yaml index bc73a00..204ebf0 100644 --- a/cfg/robomimic/pretrain/can/pre_diffusion_unet.yaml +++ b/cfg/robomimic/pretrain/can/pre_diffusion_unet.yaml @@ -55,9 +55,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/can/pre_gaussian_mlp_img.yaml b/cfg/robomimic/pretrain/can/pre_gaussian_mlp_img.yaml index d14b255..0c8c8d6 100644 --- a/cfg/robomimic/pretrain/can/pre_gaussian_mlp_img.yaml +++ b/cfg/robomimic/pretrain/can/pre_gaussian_mlp_img.yaml @@ -17,6 +17,7 @@ action_dim: 7 transition_dim: ${action_dim} horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -52,6 +53,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -64,6 +66,7 @@ model: mlp_dims: [512, 512, 512] residual_style: True fixed_std: 0.1 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -80,4 +83,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git 
a/cfg/robomimic/pretrain/lift/pre_diffusion_mlp.yaml b/cfg/robomimic/pretrain/lift/pre_diffusion_mlp.yaml index a0aad6a..63057ea 100644 --- a/cfg/robomimic/pretrain/lift/pre_diffusion_mlp.yaml +++ b/cfg/robomimic/pretrain/lift/pre_diffusion_mlp.yaml @@ -52,9 +52,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/lift/pre_diffusion_mlp_img.yaml b/cfg/robomimic/pretrain/lift/pre_diffusion_mlp_img.yaml index 2871525..5a8de94 100644 --- a/cfg/robomimic/pretrain/lift/pre_diffusion_mlp_img.yaml +++ b/cfg/robomimic/pretrain/lift/pre_diffusion_mlp_img.yaml @@ -18,6 +18,7 @@ transition_dim: ${action_dim} denoising_steps: 100 horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -27,7 +28,7 @@ wandb: shape_meta: obs: rgb: - shape: [3, 96, 96] + shape: [3, 96, 96] # not counting img_cond_steps state: shape: [9] action: @@ -55,6 +56,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -62,6 +64,7 @@ model: num_heads: 4 embed_style: embed2 embed_norm: 0 + img_cond_steps: ${img_cond_steps} augment: True spatial_emb: 128 time_dim: 32 @@ -73,9 +76,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: @@ -88,4 +89,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/lift/pre_diffusion_unet.yaml 
b/cfg/robomimic/pretrain/lift/pre_diffusion_unet.yaml index f1872b1..2322762 100644 --- a/cfg/robomimic/pretrain/lift/pre_diffusion_unet.yaml +++ b/cfg/robomimic/pretrain/lift/pre_diffusion_unet.yaml @@ -51,13 +51,11 @@ model: smaller_encoder: False cond_predict_scale: True transition_dim: ${transition_dim} - cond_dim: ${obs_dim} + cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/lift/pre_gaussian_mlp_img.yaml b/cfg/robomimic/pretrain/lift/pre_gaussian_mlp_img.yaml index 1f711b9..af0f065 100644 --- a/cfg/robomimic/pretrain/lift/pre_gaussian_mlp_img.yaml +++ b/cfg/robomimic/pretrain/lift/pre_gaussian_mlp_img.yaml @@ -17,6 +17,7 @@ action_dim: 7 transition_dim: ${action_dim} horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -52,6 +53,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -64,6 +66,7 @@ model: mlp_dims: [512, 512, 512] residual_style: True fixed_std: 0.1 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -80,4 +83,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/square/pre_diffusion_mlp.yaml b/cfg/robomimic/pretrain/square/pre_diffusion_mlp.yaml index 85540f8..1f9a13b 100644 --- a/cfg/robomimic/pretrain/square/pre_diffusion_mlp.yaml +++ b/cfg/robomimic/pretrain/square/pre_diffusion_mlp.yaml @@ -53,9 +53,7 @@ model: 
horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/square/pre_diffusion_mlp_img.yaml b/cfg/robomimic/pretrain/square/pre_diffusion_mlp_img.yaml index 7ca5e70..f6853f4 100644 --- a/cfg/robomimic/pretrain/square/pre_diffusion_mlp_img.yaml +++ b/cfg/robomimic/pretrain/square/pre_diffusion_mlp_img.yaml @@ -18,6 +18,7 @@ transition_dim: ${action_dim} denoising_steps: 100 horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -27,7 +28,7 @@ wandb: shape_meta: obs: rgb: - shape: [3, 96, 96] + shape: [3, 96, 96] # not counting img_cond_steps state: shape: [9] action: @@ -55,6 +56,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -67,15 +69,14 @@ model: time_dim: 32 mlp_dims: [768, 768, 768] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: @@ -88,4 +89,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/square/pre_diffusion_unet.yaml b/cfg/robomimic/pretrain/square/pre_diffusion_unet.yaml index cb3cac1..f74397c 100644 --- a/cfg/robomimic/pretrain/square/pre_diffusion_unet.yaml +++ b/cfg/robomimic/pretrain/square/pre_diffusion_unet.yaml @@ -55,9 +55,7 @@ model: horizon_steps: 
${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/square/pre_gaussian_mlp_img.yaml b/cfg/robomimic/pretrain/square/pre_gaussian_mlp_img.yaml index 9ae3035..065ff91 100644 --- a/cfg/robomimic/pretrain/square/pre_gaussian_mlp_img.yaml +++ b/cfg/robomimic/pretrain/square/pre_gaussian_mlp_img.yaml @@ -17,6 +17,7 @@ action_dim: 7 transition_dim: ${action_dim} horizon_steps: 4 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -52,6 +53,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -64,6 +66,7 @@ model: mlp_dims: [768, 768, 768] residual_style: True fixed_std: 0.1 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -80,4 +83,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/transport/pre_diffusion_mlp.yaml b/cfg/robomimic/pretrain/transport/pre_diffusion_mlp.yaml index 6ecf14a..9d2425f 100644 --- a/cfg/robomimic/pretrain/transport/pre_diffusion_mlp.yaml +++ b/cfg/robomimic/pretrain/transport/pre_diffusion_mlp.yaml @@ -52,9 +52,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/transport/pre_diffusion_mlp_img.yaml b/cfg/robomimic/pretrain/transport/pre_diffusion_mlp_img.yaml index 7731fcc..2da0e07 100644 --- 
a/cfg/robomimic/pretrain/transport/pre_diffusion_mlp_img.yaml +++ b/cfg/robomimic/pretrain/transport/pre_diffusion_mlp_img.yaml @@ -18,6 +18,7 @@ transition_dim: ${action_dim} denoising_steps: 100 horizon_steps: 8 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -27,7 +28,7 @@ wandb: shape_meta: obs: rgb: - shape: [3, 96, 96] + shape: [3, 96, 96] # not counting img_cond_steps state: shape: [9] action: @@ -55,6 +56,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -68,15 +70,14 @@ model: time_dim: 32 mlp_dims: [768, 768, 768] residual_style: True + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: @@ -89,4 +90,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/cfg/robomimic/pretrain/transport/pre_diffusion_unet.yaml b/cfg/robomimic/pretrain/transport/pre_diffusion_unet.yaml index 55c691a..f525a21 100644 --- a/cfg/robomimic/pretrain/transport/pre_diffusion_unet.yaml +++ b/cfg/robomimic/pretrain/transport/pre_diffusion_unet.yaml @@ -55,9 +55,7 @@ model: horizon_steps: ${horizon_steps} obs_dim: ${obs_dim} action_dim: ${action_dim} - transition_dim: ${transition_dim} denoising_steps: ${denoising_steps} - cond_steps: ${cond_steps} device: ${device} ema: diff --git a/cfg/robomimic/pretrain/transport/pre_gaussian_mlp_img.yaml b/cfg/robomimic/pretrain/transport/pre_gaussian_mlp_img.yaml index 9d475aa..1f6ba93 100644 
--- a/cfg/robomimic/pretrain/transport/pre_gaussian_mlp_img.yaml +++ b/cfg/robomimic/pretrain/transport/pre_gaussian_mlp_img.yaml @@ -17,6 +17,7 @@ action_dim: 14 transition_dim: ${action_dim} horizon_steps: 8 cond_steps: 1 +img_cond_steps: 1 wandb: entity: ${oc.env:DPPO_WANDB_ENTITY} @@ -26,7 +27,7 @@ wandb: shape_meta: obs: rgb: - shape: [3, 96, 96] + shape: [3, 96, 96] # not counting img_cond_steps state: shape: [9] action: @@ -52,6 +53,7 @@ model: backbone: _target_: model.common.vit.VitEncoder obs_shape: ${shape_meta.obs.rgb.shape} + num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated cfg: patch_size: 8 depth: 1 @@ -65,6 +67,7 @@ model: mlp_dims: [768, 768, 768] residual_style: True fixed_std: 0.1 + img_cond_steps: ${img_cond_steps} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} horizon_steps: ${horizon_steps} transition_dim: ${transition_dim} @@ -81,4 +84,5 @@ train_dataset: horizon_steps: ${horizon_steps} max_n_episodes: 100 cond_steps: ${cond_steps} + img_cond_steps: ${img_cond_steps} device: ${device} \ No newline at end of file diff --git a/env/gym_utils/__init__.py b/env/gym_utils/__init__.py index 632b6a4..c29b7f0 100644 --- a/env/gym_utils/__init__.py +++ b/env/gym_utils/__init__.py @@ -213,7 +213,7 @@ def make_async( "render.modes": ["human", "rgb_array", "depth_array"], "video.frames_per_second": 12, } - return MultiStep(env=env) # use all defaults + return MultiStep(env=env, n_obs_steps=wrappers.multi_step.n_obs_steps) env_fns = [_make_env for _ in range(num_envs)] return ( diff --git a/env/gym_utils/wrapper/d3il_lowdim.py b/env/gym_utils/wrapper/d3il_lowdim.py index 80d5591..4e1f2f8 100644 --- a/env/gym_utils/wrapper/d3il_lowdim.py +++ b/env/gym_utils/wrapper/d3il_lowdim.py @@ -1,10 +1,12 @@ """ Environment wrapper for D3IL environments with state observations. +For consistency, we will use Dict{} for the observation space, with the key "state" for the state observation. 
""" import numpy as np import gym +from gym import spaces class D3ilLowdimWrapper(gym.Env): @@ -12,28 +14,27 @@ class D3ilLowdimWrapper(gym.Env): self, env, normalization_path, - # init_state=None, - # render_hw=(256, 256), - # render_camera_name="agentview", ): self.env = env - # self.init_state = init_state - # self.render_hw = render_hw - # self.render_camera_name = render_camera_name # setup spaces self.action_space = env.action_space - self.observation_space = env.observation_space normalization = np.load(normalization_path) self.obs_min = normalization["obs_min"] self.obs_max = normalization["obs_max"] self.action_min = normalization["action_min"] self.action_max = normalization["action_max"] - # def get_observation(self): - # raw_obs = self.env.get_observation() - # obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0) - # return obs + self.observation_space = spaces.Dict() + obs_example = self.env.reset() + low = np.full_like(obs_example, fill_value=-1) + high = np.full_like(obs_example, fill_value=1) + self.observation_space["state"] = spaces.Box( + low=low, + high=high, + shape=low.shape, + dtype=low.dtype, + ) def seed(self, seed=None): if seed is not None: @@ -48,9 +49,6 @@ class D3ilLowdimWrapper(gym.Env): new_seed = options.get( "seed", None ) # used to set all environments to specified seeds - # if self.init_state is not None: - # # always reset to the same state to be compatible with gym - # self.env.reset_to({"states": self.init_state}) if new_seed is not None: self.seed(seed=new_seed) obs = self.env.reset() @@ -60,7 +58,7 @@ class D3ilLowdimWrapper(gym.Env): # normalize obs = self.normalize_obs(obs) - return obs + return {"state": obs} def normalize_obs(self, obs): return 2 * ((obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5) @@ -75,7 +73,7 @@ class D3ilLowdimWrapper(gym.Env): # normalize obs = self.normalize_obs(obs) - return obs, reward, done, info + return {"state": obs}, reward, done, info def render(self, 
mode="rgb_array"): h, w = self.render_hw diff --git a/env/gym_utils/wrapper/furniture.py b/env/gym_utils/wrapper/furniture.py index 0c5776d..4cf7c16 100644 --- a/env/gym_utils/wrapper/furniture.py +++ b/env/gym_utils/wrapper/furniture.py @@ -71,8 +71,7 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper): nobs = self.process_obs(obs) self.best_reward = torch.zeros(self.env.num_envs).to(self.device) self.done = list() - - return nobs + return {"state": nobs} def reset_arg(self, options_list=None): return self.reset() @@ -80,7 +79,6 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper): def reset_one_arg(self, env_ind=None, options=None): if env_ind is not None: env_ind = torch.tensor([env_ind], device=self.device) - return self.reset() def step(self, action: np.ndarray): @@ -109,7 +107,7 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper): nobs: np.ndarray = self.process_obs(obs) done: np.ndarray = done.squeeze().cpu().numpy() - return (nobs, reward, done, info) + return {"state": nobs}, reward, done, info def _inner_step(self, action_chunk: torch.Tensor): dones = torch.zeros( diff --git a/env/gym_utils/wrapper/mujoco_locomotion_lowdim.py b/env/gym_utils/wrapper/mujoco_locomotion_lowdim.py index 7b2247f..1c0dc3d 100644 --- a/env/gym_utils/wrapper/mujoco_locomotion_lowdim.py +++ b/env/gym_utils/wrapper/mujoco_locomotion_lowdim.py @@ -1,10 +1,12 @@ """ Environment wrapper for Gym environments (MuJoCo locomotion tasks) with state observations. +For consistency, we will use Dict{} for the observation space, with the key "state" for the state observation. 
""" import numpy as np import gym +from gym import spaces class MujocoLocomotionLowdimWrapper(gym.Env): @@ -17,13 +19,23 @@ class MujocoLocomotionLowdimWrapper(gym.Env): # setup spaces self.action_space = env.action_space - self.observation_space = env.observation_space normalization = np.load(normalization_path) self.obs_min = normalization["obs_min"] self.obs_max = normalization["obs_max"] self.action_min = normalization["action_min"] self.action_max = normalization["action_max"] + self.observation_space = spaces.Dict() + obs_example = self.env.reset() + low = np.full_like(obs_example, fill_value=-1) + high = np.full_like(obs_example, fill_value=1) + self.observation_space["state"] = spaces.Box( + low=low, + high=high, + shape=low.shape, + dtype=low.dtype, + ) + def seed(self, seed=None): if seed is not None: np.random.seed(seed=seed) @@ -40,7 +52,7 @@ class MujocoLocomotionLowdimWrapper(gym.Env): # normalize obs = self.normalize_obs(raw_obs) - return obs + return {"state": obs} def normalize_obs(self, obs): return 2 * ((obs - self.obs_min) / (self.obs_max - self.obs_min + 1e-6) - 0.5) @@ -55,7 +67,7 @@ class MujocoLocomotionLowdimWrapper(gym.Env): # normalize obs = self.normalize_obs(raw_obs) - return obs, reward, done, info + return {"state": obs}, reward, done, info def render(self, **kwargs): return self.env.render() diff --git a/env/gym_utils/wrapper/multi_step.py b/env/gym_utils/wrapper/multi_step.py index d3b031d..758caa8 100644 --- a/env/gym_utils/wrapper/multi_step.py +++ b/env/gym_utils/wrapper/multi_step.py @@ -3,6 +3,7 @@ Multi-step wrapper. Allow executing multiple environmnt steps. 
Returns stacked o Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/gym_util/multistep_wrapper.py +TODO: allow cond_steps != img_cond_steps (should be implemented in training scripts, not here) """ import gym @@ -11,8 +12,6 @@ from gym import spaces import numpy as np from collections import defaultdict, deque -# import dill - def stack_repeated(x, n): return np.repeat(np.expand_dims(x, axis=0), n, axis=0) @@ -157,12 +156,11 @@ class MultiStep(gym.Wrapper): done = True self.done.append(done) self._add_info(info) - observation = self._get_obs(self.n_obs_steps) reward = aggregate(self.reward, self.reward_agg_method) done = aggregate(self.done, "max") info = dict_take_last_n(self.info, self.n_obs_steps) - if self.pass_full_observations: # right now this assume n_obs_steps = 1 + if self.pass_full_observations: info["full_obs"] = self._get_obs(act_step + 1) # In mujoco case, done can happen within the loop above @@ -206,22 +204,6 @@ class MultiStep(gym.Wrapper): """Not the best design""" return self.env.render(**kwargs) - # def get_rewards(self): - # return self.reward - - # def get_attr(self, name): - # return getattr(self, name) - - # def run_dill_function(self, dill_fn): - # fn = dill.loads(dill_fn) - # return fn(self) - - # def get_infos(self): - # result = dict() - # for k, v in self.info.items(): - # result[k] = list(v) - # return result - if __name__ == "__main__": import os diff --git a/env/gym_utils/wrapper/robomimic_image.py b/env/gym_utils/wrapper/robomimic_image.py index aa078ea..32dac4b 100644 --- a/env/gym_utils/wrapper/robomimic_image.py +++ b/env/gym_utils/wrapper/robomimic_image.py @@ -90,9 +90,7 @@ class RobomimicImageWrapper(gym.Env): action = (action + 1) / 2 # [-1, 1] -> [0, 1] return action * (self.action_max - self.action_min) + self.action_min - def get_observation(self, raw_obs=None): - if raw_obs is None: - raw_obs = self.env.get_observation() + def get_observation(self, raw_obs): obs = {"rgb": None, 
"state": None} # stack rgb if multiple cameras for key in self.obs_keys: if key in self.image_keys: diff --git a/env/gym_utils/wrapper/robomimic_lowdim.py b/env/gym_utils/wrapper/robomimic_lowdim.py index adbb43f..3e7378a 100644 --- a/env/gym_utils/wrapper/robomimic_lowdim.py +++ b/env/gym_utils/wrapper/robomimic_lowdim.py @@ -3,11 +3,12 @@ Environment wrapper for Robomimic environments with state observations. Modified from https://github.com/real-stanford/diffusion_policy/blob/main/diffusion_policy/env/robomimic/robomimic_lowdim_wrapper.py +For consistency, we will use Dict{} for the observation space, with the key "state" for the state observation. """ import numpy as np import gym -from gym.spaces import Box +from gym import spaces import imageio @@ -28,7 +29,6 @@ class RobomimicLowdimWrapper(gym.Env): render_camera_name="agentview", ): self.env = env - self.obs_keys = low_dim_keys self.init_state = init_state self.render_hw = render_hw self.render_camera_name = render_camera_name @@ -44,19 +44,24 @@ class RobomimicLowdimWrapper(gym.Env): self.action_min = normalization["action_min"] self.action_max = normalization["action_max"] - # setup spaces - use [-1, 1] + # setup spaces low = np.full(env.action_dimension, fill_value=-1) high = np.full(env.action_dimension, fill_value=1) - self.action_space = Box( + self.action_space = gym.spaces.Box( low=low, high=high, shape=low.shape, dtype=low.dtype, ) - obs_example = self.get_observation() + self.obs_keys = low_dim_keys + self.observation_space = spaces.Dict() + obs_example_full = self.env.get_observation() + obs_example = np.concatenate( + [obs_example_full[key] for key in self.obs_keys], axis=0 + ) low = np.full_like(obs_example, fill_value=-1) high = np.full_like(obs_example, fill_value=1) - self.observation_space = Box( + self.observation_space["state"] = spaces.Box( low=low, high=high, shape=low.shape, @@ -75,12 +80,11 @@ class RobomimicLowdimWrapper(gym.Env): action = (action + 1) / 2 # [-1, 1] -> [0, 1] return 
action * (self.action_max - self.action_min) + self.action_min - def get_observation(self): - raw_obs = self.env.get_observation() - raw_obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0) + def get_observation(self, raw_obs): + obs = {"state": np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)} if self.normalize: - return self.normalize_obs(raw_obs) - return raw_obs + obs["state"] = self.normalize_obs(obs["state"]) + return obs def seed(self, seed=None): if seed is not None: @@ -90,7 +94,6 @@ class RobomimicLowdimWrapper(gym.Env): def reset(self, options={}, **kwargs): """Ignore passed-in arguments like seed""" - # Close video if exists if self.video_writer is not None: self.video_writer.close() @@ -106,24 +109,20 @@ class RobomimicLowdimWrapper(gym.Env): ) # used to set all environments to specified seeds if self.init_state is not None: # always reset to the same state to be compatible with gym - self.env.reset_to({"states": self.init_state}) + raw_obs = self.env.reset_to({"states": self.init_state}) elif new_seed is not None: self.seed(seed=new_seed) - self.env.reset() + raw_obs = self.env.reset() else: # random reset - self.env.reset() - return self.get_observation() + raw_obs = self.env.reset() + return self.get_observation(raw_obs) def step(self, action): if self.normalize: action = self.unnormalize_action(action) raw_obs, reward, done, info = self.env.step(action) - raw_obs = np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0) - if self.normalize: - obs = self.normalize_obs(raw_obs) - else: - obs = raw_obs + obs = self.get_observation(raw_obs) # render if specified if self.video_writer is not None: diff --git a/model/common/critic.py b/model/common/critic.py index dc59265..e425bd3 100644 --- a/model/common/critic.py +++ b/model/common/critic.py @@ -3,6 +3,7 @@ Critic networks. 
""" +from typing import Union import torch import copy import einops @@ -17,7 +18,7 @@ class CriticObs(torch.nn.Module): def __init__( self, - obs_dim, + cond_dim, mlp_dims, activation_type="Mish", use_layernorm=False, @@ -25,7 +26,7 @@ class CriticObs(torch.nn.Module): **kwargs, ): super().__init__() - mlp_dims = [obs_dim] + mlp_dims + [1] + mlp_dims = [cond_dim] + mlp_dims + [1] if residual_style: self.Q1 = ResidualMLP( mlp_dims, @@ -42,9 +43,20 @@ class CriticObs(torch.nn.Module): verbose=False, ) - def forward(self, x): - x = x.view(x.size(0), -1) - q1 = self.Q1(x) + def forward(self, cond: Union[dict, torch.Tensor]): + """ + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + or (B, num_feature) from ViT encoder + """ + if isinstance(cond, dict): + B = len(cond["state"]) + + # flatten history + state = cond["state"].view(B, -1) + else: + state = cond + q1 = self.Q1(state) return q1 @@ -53,7 +65,7 @@ class CriticObsAct(torch.nn.Module): def __init__( self, - obs_dim, + cond_dim, mlp_dims, action_dim, action_steps=1, @@ -63,7 +75,7 @@ class CriticObsAct(torch.nn.Module): **kwargs, ): super().__init__() - mlp_dims = [obs_dim + action_dim * action_steps] + mlp_dims + [1] + mlp_dims = [cond_dim + action_dim * action_steps] + mlp_dims + [1] if residual_tyle: self.Q1 = ResidualMLP( mlp_dims, @@ -81,9 +93,21 @@ class CriticObsAct(torch.nn.Module): ) self.Q2 = copy.deepcopy(self.Q1) - def forward(self, x, action): - x = x.view(x.size(0), -1) - x = torch.cat((x, action), dim=-1) + def forward(self, cond: dict, action): + """ + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + action: (B, Ta, Da) + """ + B = len(cond["state"]) + + # flatten history + state = cond["state"].view(B, -1) + + # flatten action + action = action.view(B, -1) + + x = torch.cat((state, action), dim=-1) q1 = self.Q1(x) q2 = self.Q2(x) return q1.squeeze(1), q2.squeeze(1) @@ -95,7 +119,8 @@ class ViTCritic(CriticObs): def __init__( self, 
backbone, - obs_dim, + cond_dim, + img_cond_steps=1, spatial_emb=128, patch_repr_dim=128, dropout=0, @@ -104,14 +129,16 @@ class ViTCritic(CriticObs): **kwargs, ): # update input dim to mlp - mlp_obs_dim = spatial_emb * num_img + obs_dim - super().__init__(obs_dim=mlp_obs_dim, **kwargs) + mlp_obs_dim = spatial_emb * num_img + cond_dim + super().__init__(cond_dim=mlp_obs_dim, **kwargs) self.backbone = backbone + self.num_img = num_img + self.img_cond_steps = img_cond_steps if num_img > 1: self.compress1 = SpatialEmb( num_patch=121, # TODO: repr_dim // patch_repr_dim, patch_dim=patch_repr_dim, - prop_dim=obs_dim, + prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, ) @@ -120,7 +147,7 @@ class ViTCritic(CriticObs): self.compress = SpatialEmb( num_patch=121, patch_dim=patch_repr_dim, - prop_dim=obs_dim, + prop_dim=cond_dim, proj_dim=spatial_emb, dropout=dropout, ) @@ -130,23 +157,39 @@ class ViTCritic(CriticObs): def forward( self, - obs: dict, + cond: dict, no_augment=False, ): - # flatten cond_dim if exists - if obs["rgb"].ndim == 5: - rgb = einops.rearrange(obs["rgb"], "b d c h w -> (b d) c h w") + """ + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + no_augment: whether to skip augmentation + + TODO long term: more flexible handling of cond + """ + B, T_rgb, C, H, W = cond["rgb"].shape + + # flatten history + state = cond["state"].view(B, -1) + + # Take recent images --- sometimes we want to use fewer img_cond_steps than cond_steps (e.g., 1 image but 3 prio) + rgb = cond["rgb"][:, -self.img_cond_steps :] + + # concatenate images in cond by channels + if self.num_img > 1: + rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W) + rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w") else: - rgb = obs["rgb"] - if obs["state"].ndim == 3: - state = einops.rearrange(obs["state"], "b d c -> (b d) c") - else: - state = obs["state"] + rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w") + + # convert rgb 
to float32 for augmentation + rgb = rgb.float() # get vit output - pass in two images separately - if rgb.shape[1] == 6: # TODO: properly handle multiple images - rgb1 = rgb[:, :3] - rgb2 = rgb[:, 3:] + if self.num_img > 1: # TODO: properly handle multiple images + rgb1 = rgb[:, 0] + rgb2 = rgb[:, 1] if self.augment and not no_augment: rgb1 = self.aug(rgb1) rgb2 = self.aug(rgb2) diff --git a/model/common/gaussian.py b/model/common/gaussian.py index 2ce7fb6..42bbc9e 100644 --- a/model/common/gaussian.py +++ b/model/common/gaussian.py @@ -22,7 +22,6 @@ class GaussianModel(torch.nn.Module): super().__init__() self.device = device self.network = network.to(device) - self.horizon_steps = horizon_steps if network_path is not None: checkpoint = torch.load( network_path, map_location=self.device, weights_only=True @@ -35,15 +34,15 @@ class GaussianModel(torch.nn.Module): log.info( f"Number of network parameters: {sum(p.numel() for p in self.parameters())}" ) + self.horizon_steps = horizon_steps - def loss(self, true_action, cond, ent_coef): + def loss( + self, + true_action, + cond, + ent_coef, + ): B = len(true_action) - if isinstance( - cond, dict - ): # image and state, only using one step observation right now - cond = cond[0] - else: - cond = cond[0].reshape(B, -1) dist = self.forward_train( cond, deterministic=False, @@ -79,10 +78,7 @@ class GaussianModel(torch.nn.Module): randn_clip_value=10, network_override=None, ): - if isinstance(cond, dict): - B = cond["state"].shape[0] - else: - B = cond.shape[0] + B = len(cond["state"]) if "state" in cond else len(cond["rgb"]) T = self.horizon_steps dist = self.forward_train( cond, diff --git a/model/common/gmm.py b/model/common/gmm.py index fd42666..cc66cf8 100644 --- a/model/common/gmm.py +++ b/model/common/gmm.py @@ -16,20 +16,34 @@ class GMMModel(torch.nn.Module): self, network, horizon_steps, + network_path=None, device="cuda:0", **kwargs, ): super().__init__() self.device = device self.network = network.to(device) - 
self.horizon_steps = horizon_steps + if network_path is not None: + checkpoint = torch.load( + network_path, map_location=self.device, weights_only=True + ) + self.load_state_dict( + checkpoint["model"], + strict=False, + ) + logging.info("Loaded actor from %s", network_path) log.info( f"Number of network parameters: {sum(p.numel() for p in self.parameters())}" ) + self.horizon_steps = horizon_steps - def loss(self, true_action, obs_cond, **kwargs): + def loss( + self, + true_action, + cond, + **kwargs, + ): B = len(true_action) - cond = obs_cond[0].reshape(B, -1) dist, entropy, _ = self.forward_train( cond, deterministic=False, @@ -72,7 +86,7 @@ class GMMModel(torch.nn.Module): return dist, approx_entropy, std def forward(self, cond, deterministic=False): - B = cond.shape[0] + B = len(cond["state"]) if "state" in cond else len(cond["rgb"]) T = self.horizon_steps dist, _, _ = self.forward_train( cond, diff --git a/model/common/mlp_gaussian.py b/model/common/mlp_gaussian.py index b5ea047..73adb16 100644 --- a/model/common/mlp_gaussian.py +++ b/model/common/mlp_gaussian.py @@ -21,6 +21,7 @@ class Gaussian_VisionMLP(nn.Module): transition_dim, horizon_steps, cond_dim, + img_cond_steps=1, mlp_dims=[256, 256, 256], activation_type="Mish", residual_style=False, @@ -44,6 +45,8 @@ class Gaussian_VisionMLP(nn.Module): if augment: self.aug = RandomShiftsAug(pad=4) self.augment = augment + self.num_img = num_img + self.img_cond_steps = img_cond_steps if spatial_emb > 0: assert spatial_emb > 1, "this is the dimension" if num_img > 1: @@ -109,24 +112,31 @@ class Gaussian_VisionMLP(nn.Module): self.fixed_std = fixed_std self.learn_fixed_std = learn_fixed_std - def forward(self, x): - B = len(x["state"]) - device = x["state"].device + def forward(self, cond): + B = len(cond["rgb"]) + device = cond["rgb"].device + _, T_rgb, C, H, W = cond["rgb"].shape - # flatten cond_dim if exists - if x["rgb"].ndim == 5: - rgb = einops.rearrange(x["rgb"], "b d c h w -> (b d) c h w") + # flatten 
history + state = cond["state"].view(B, -1) + + # Take recent images --- sometimes we want to use fewer img_cond_steps than cond_steps (e.g., 1 image but 3 prio) + rgb = cond["rgb"][:, -self.img_cond_steps :] + + # concatenate images in cond by channels + if self.num_img > 1: + rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W) + rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w") else: - rgb = x["rgb"] - if x["state"].ndim == 3: - state = einops.rearrange(x["state"], "b d c -> (b d) c") - else: - state = x["state"] + rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w") + + # convert rgb to float32 for augmentation + rgb = rgb.float() # get vit output - pass in two images separately - if rgb.shape[1] == 6: # TODO: properly handle multiple images - rgb1 = rgb[:, :3] - rgb2 = rgb[:, 3:] + if self.num_img > 1: # TODO: properly handle multiple images + rgb1 = rgb[:, 0] + rgb2 = rgb[:, 1] if self.augment: rgb1 = self.aug(rgb1) rgb2 = self.aug(rgb2) @@ -223,11 +233,15 @@ class Gaussian_MLP(nn.Module): self.fixed_std = fixed_std self.learn_fixed_std = learn_fixed_std - def forward(self, x): - B = len(x) + def forward(self, cond): + B = len(cond["state"]) + device = cond["state"].device + + # flatten history + state = cond["state"].view(B, -1) # mlp - out_mean = self.mlp_mean(x) + out_mean = self.mlp_mean(state) out_mean = torch.tanh(out_mean).view( B, self.horizon_steps * self.transition_dim ) # tanh squashing in [-1, 1] @@ -238,9 +252,9 @@ class Gaussian_MLP(nn.Module): out_scale = out_scale.view(1, self.transition_dim) out_scale = out_scale.repeat(B, self.horizon_steps) elif self.use_fixed_std: - out_scale = torch.ones_like(out_mean).to(x.device) * self.fixed_std + out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: - out_logvar = self.mlp_logvar(x).view( + out_logvar = self.mlp_logvar(state).view( B, self.horizon_steps * self.transition_dim ) out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max) diff --git 
a/model/common/mlp_gmm.py b/model/common/mlp_gmm.py index 9262d97..8844b63 100644 --- a/model/common/mlp_gmm.py +++ b/model/common/mlp_gmm.py @@ -77,11 +77,15 @@ class GMM_MLP(nn.Module): use_layernorm=use_layernorm, ) - def forward(self, x): - B = len(x) + def forward(self, cond): + B = len(cond["state"]) + device = cond["state"].device + + # flatten history + state = cond["state"].view(B, -1) # mlp - out_mean = self.mlp_mean(x) + out_mean = self.mlp_mean(state) out_mean = torch.tanh(out_mean).view( B, self.num_modes, self.horizon_steps * self.transition_dim ) # tanh squashing in [-1, 1] @@ -92,15 +96,15 @@ class GMM_MLP(nn.Module): out_scale = out_scale.view(1, self.num_modes, self.transition_dim) out_scale = out_scale.repeat(B, 1, self.horizon_steps) elif self.use_fixed_std: - out_scale = torch.ones_like(out_mean).to(x.device) * self.fixed_std + out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: - out_logvar = self.mlp_logvar(x).view( + out_logvar = self.mlp_logvar(state).view( B, self.num_modes, self.horizon_steps * self.transition_dim ) out_logvar = torch.clamp(out_logvar, self.logvar_min, self.logvar_max) out_scale = torch.exp(0.5 * out_logvar) - out_weights = self.mlp_weights(x) + out_weights = self.mlp_weights(state) out_weights = out_weights.view(B, self.num_modes) return out_mean, out_scale, out_weights diff --git a/model/common/modules.py b/model/common/modules.py index 0a72e4f..8a83be0 100644 --- a/model/common/modules.py +++ b/model/common/modules.py @@ -67,3 +67,21 @@ class RandomShiftsAug: return nn.functional.grid_sample( x, grid, padding_mode="zeros", align_corners=False ) + + +# test random shift +if __name__ == "__main__": + from PIL import Image + import requests + import numpy as np + + image_url = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_30.jpg" + image = Image.open(requests.get(image_url, 
stream=True).raw) + image = image.resize((96, 96)) + + image = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() + aug = RandomShiftsAug(pad=4) + image_aug = aug(image) + image_aug = image_aug.squeeze().permute(1, 2, 0).numpy() + image_aug = Image.fromarray(image_aug.astype(np.uint8)) + image_aug.show() diff --git a/model/common/transformer.py b/model/common/transformer.py index 653f231..1c38637 100644 --- a/model/common/transformer.py +++ b/model/common/transformer.py @@ -70,13 +70,15 @@ class Gaussian_Transformer(nn.Module): ) def forward(self, cond): - """ - cond: (B,cond_dim) - output: (B,horizon*transition) - """ - B = len(cond) - cond = cond.unsqueeze(1) # (B,1,cond_dim) - out, _ = self.transformer(cond) # (B,horizon,output_dim) + B = len(cond["state"]) + device = cond["state"].device + + # flatten history + state = cond["state"].view(B, -1) + + # input to transformer + state = state.unsqueeze(1) # (B,1,cond_dim) + out, _ = self.transformer(state) # (B,horizon,output_dim) # use the first half of the output as mean out_mean = torch.tanh(out[:, :, : self.transition_dim]) @@ -88,7 +90,7 @@ class Gaussian_Transformer(nn.Module): out_scale = out_scale.view(1, self.transition_dim) out_scale = out_scale.repeat(B, self.horizon_steps) elif self.fixed_std is not None: - out_scale = torch.ones_like(out_mean).to(cond.device) * self.fixed_std + out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: out_logvar = out[:, :, self.transition_dim :] out_logvar = out_logvar.reshape(B, self.horizon_steps * self.transition_dim) @@ -164,14 +166,16 @@ class GMM_Transformer(nn.Module): self.modes_head = nn.Linear(horizon_steps * transformer_embed_dim, num_modes) def forward(self, cond): - """ - cond: (B,cond_dim) - output: (B,horizon*transition) - """ - B = len(cond) - cond = cond.unsqueeze(1) # (B,1,cond_dim) + B = len(cond["state"]) + device = cond["state"].device + + # flatten history + state = cond["state"].view(B, -1) + + # input to transformer 
+ state = state.unsqueeze(1) # (B,1,cond_dim) out, out_prehead = self.transformer( - cond + state ) # (B,horizon,output_dim), (B,horizon,emb_dim) # use the first half of the output as mean @@ -190,7 +194,7 @@ class GMM_Transformer(nn.Module): out_scale = out_scale.view(1, self.num_modes, self.transition_dim) out_scale = out_scale.repeat(B, 1, self.horizon_steps) elif self.fixed_std is not None: - out_scale = torch.ones_like(out_mean).to(cond.device) * self.fixed_std + out_scale = torch.ones_like(out_mean).to(device) * self.fixed_std else: out_logvar = out[ :, :, self.num_modes * self.transition_dim : -self.num_modes diff --git a/model/common/vit.py b/model/common/vit.py index 942b9bb..ec6f2ea 100644 --- a/model/common/vit.py +++ b/model/common/vit.py @@ -24,7 +24,12 @@ class VitEncoderConfig: class VitEncoder(nn.Module): - def __init__(self, obs_shape: List[int], cfg: VitEncoderConfig): + def __init__( + self, + obs_shape: List[int], + cfg: VitEncoderConfig, + num_channel=3, + ): super().__init__() self.obs_shape = obs_shape self.cfg = cfg @@ -34,6 +39,7 @@ class VitEncoder(nn.Module): embed_norm=cfg.embed_norm, num_head=cfg.num_heads, depth=cfg.depth, + num_channel=num_channel, ) self.num_patch = self.vit.num_patches @@ -50,9 +56,9 @@ class VitEncoder(nn.Module): class PatchEmbed1(nn.Module): - def __init__(self, embed_dim): + def __init__(self, embed_dim, num_channel=3): super().__init__() - self.conv = nn.Conv2d(3, embed_dim, kernel_size=8, stride=8) + self.conv = nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=8) self.num_patch = 144 self.patch_dim = embed_dim @@ -64,10 +70,10 @@ class PatchEmbed1(nn.Module): class PatchEmbed2(nn.Module): - def __init__(self, embed_dim, use_norm): + def __init__(self, embed_dim, use_norm, num_channel=3): super().__init__() layers = [ - nn.Conv2d(3, embed_dim, kernel_size=8, stride=4), + nn.Conv2d(num_channel, embed_dim, kernel_size=8, stride=4), nn.GroupNorm(embed_dim, embed_dim) if use_norm else nn.Identity(), 
nn.ReLU(), nn.Conv2d(embed_dim, embed_dim, kernel_size=3, stride=2), @@ -132,13 +138,23 @@ class TransformerLayer(nn.Module): class MinVit(nn.Module): - def __init__(self, embed_style, embed_dim, embed_norm, num_head, depth): + def __init__( + self, + embed_style, + embed_dim, + embed_norm, + num_head, + depth, + num_channel=3, + ): super().__init__() if embed_style == "embed1": - self.patch_embed = PatchEmbed1(embed_dim) + self.patch_embed = PatchEmbed1(embed_dim, num_channel=num_channel) elif embed_style == "embed2": - self.patch_embed = PatchEmbed2(embed_dim, use_norm=embed_norm) + self.patch_embed = PatchEmbed2( + embed_dim, use_norm=embed_norm, num_channel=num_channel + ) else: assert False @@ -217,20 +233,10 @@ def test_transformer_layer(): if __name__ == "__main__": - import rich.traceback - import pyrallis - - @dataclass - class MainConfig: - net_type: str = "vit" - obs_shape: list[int] = field(default_factory=lambda: [3, 96, 96]) - vit: VitEncoderConfig = field(default_factory=lambda: VitEncoderConfig()) - - rich.traceback.install() - cfg = pyrallis.parse(config_class=MainConfig) # type: ignore - enc = VitEncoder(cfg.obs_shape, cfg.vit) + obs_shape = [6, 96, 96] + enc = VitEncoder([6, 96, 96], VitEncoderConfig()) print(enc) - x = torch.rand(1, *cfg.obs_shape) * 255 + x = torch.rand(1, *obs_shape) * 255 print("output size:", enc(x, flatten=False).size()) print("repr dim:", enc.repr_dim, ", real dim:", enc(x, flatten=True).size()) diff --git a/model/diffusion/diffusion.py b/model/diffusion/diffusion.py index 796925c..fcc69ba 100644 --- a/model/diffusion/diffusion.py +++ b/model/diffusion/diffusion.py @@ -8,7 +8,7 @@ Annotated DDIM/DDPM: https://nn.labml.ai/diffusion/stable_diffusion/sampler/ddpm """ -from typing import Optional, Union +from typing import Union import logging import torch from torch import nn @@ -17,13 +17,12 @@ import torch.nn.functional as F log = logging.getLogger(__name__) from model.diffusion.sampling import ( - make_timesteps, extract, 
cosine_beta_schedule, ) from collections import namedtuple -Sample = namedtuple("Sample", "trajectories values chains") +Sample = namedtuple("Sample", "trajectories chains") class DiffusionModel(nn.Module): @@ -34,9 +33,7 @@ class DiffusionModel(nn.Module): horizon_steps, obs_dim, action_dim, - transition_dim, network_path=None, - cond_steps=1, device="cuda:0", # DDPM parameters denoising_steps=100, @@ -53,11 +50,9 @@ class DiffusionModel(nn.Module): self.horizon_steps = horizon_steps self.obs_dim = obs_dim self.action_dim = action_dim - self.transition_dim = transition_dim self.denoising_steps = int(denoising_steps) self.denoised_clip_value = denoised_clip_value self.predict_epsilon = predict_epsilon - self.cond_steps = cond_steps self.use_ddim = use_ddim self.ddim_steps = ddim_steps @@ -216,52 +211,11 @@ class DiffusionModel(nn.Module): @torch.no_grad() def forward( self, - cond: Optional[torch.Tensor], + cond, return_chain=True, + **kwargs, ): - """ - Forward sampling through denoising steps. 
- - Args: - cond: (batch_size, horizon, transition_dim) - return_chain: whether to return the chain of samples or only the final denoised sample - Return: - Sample: namedtuple with fields: - trajectories: (batch_size, horizon_steps, transition_dim) - values: (batch_size, ) - chain: (batch_size, denoising_steps + 1, horizon_steps, transition_dim) - """ - device = self.betas.device - if isinstance(cond, dict): - B = cond[list(cond.keys())[0]].shape[0] - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps].reshape(B, -1) - shape = (B, self.horizon_steps, self.transition_dim) - - # Loop - x = torch.randn(shape, device=device) - chain = [x] if return_chain else None - if self.use_ddim: - t_all = self.ddim_t - else: - t_all = list(reversed(range(self.denoising_steps))) - for i, t in enumerate(t_all): - t_b = make_timesteps(B, t, device) - index_b = make_timesteps(B, i, device) - mu, logvar = self.p_mean_var(x=x, t=t_b, cond=cond, index=index_b) - std = torch.exp(0.5 * logvar) - - # no noise when t == 0 - noise = torch.randn_like(x) - noise[t == 0] = 0 - x = mu + std * noise - if return_chain: - chain.append(x) - if return_chain: - chain = torch.stack(chain, dim=1) - values = torch.zeros(len(x), device=x.device) # not considering the value for now - return Sample(x, values, chain) + raise NotImplementedError # ---------- Supervised training ----------# @@ -275,23 +229,18 @@ class DiffusionModel(nn.Module): def p_losses( self, x_start, - obs_cond: Union[dict, torch.Tensor], + cond: Union[dict, torch.Tensor], t, ): """ If predicting epsilon: E_{t, x0, ε} [||ε - ε_θ(√α̅ₜx0 + √(1-α̅ₜ)ε, t)||² Args: - x_start: (batch_size, horizon_steps, transition_dim) - obs_cond: dict with keys as step and value as observation + x_start: (batch_size, horizon_steps, action_dim) + cond: dict with keys as step and value as observation t: batch of integers """ device = x_start.device - B = x_start.shape[0] - if isinstance(obs_cond[0], dict): - cond = obs_cond[0] # keep the dictionary 
and the network will extract img and prio - else: - cond = obs_cond[0].reshape(B, -1) # Forward process noise = torch.randn_like(x_start, device=device) diff --git a/model/diffusion/diffusion_dipo.py b/model/diffusion/diffusion_dipo.py index 4e4630c..a051ff4 100644 --- a/model/diffusion/diffusion_dipo.py +++ b/model/diffusion/diffusion_dipo.py @@ -40,19 +40,19 @@ class DIPODiffusion(DiffusionModel): # Whether to clamp sampled action between [-1, 1] self.clamp_action = clamp_action + # ---------- RL training ----------# + def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma): # get current Q-function - actions_flat = torch.flatten(actions, start_dim=-2) - current_q1, current_q2 = self.critic(obs, actions_flat) + current_q1, current_q2 = self.critic(obs, actions) # get next Q-function next_actions = self.forward( cond=next_obs, deterministic=False, - ) # in DiffusionModel, forward() has no gradient, which is desired here. - next_actions_flat = torch.flatten(next_actions, start_dim=-2) - next_q1, next_q2 = self.critic(next_obs, next_actions_flat) + ) # forward() has no gradient, which is desired here. 
+ next_q1, next_q2 = self.critic(next_obs, next_actions) next_q = torch.min(next_q1, next_q2) # terminal state mask @@ -73,6 +73,8 @@ class DIPODiffusion(DiffusionModel): return loss_critic + # ---------- Sampling ----------#`` + # override @torch.no_grad() def forward( @@ -81,15 +83,10 @@ class DIPODiffusion(DiffusionModel): deterministic=False, ): device = self.betas.device - B = cond.shape[0] - if isinstance(cond, dict): - raise NotImplementedError("Not implemented for images") - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps] + B = len(cond["state"]) # Loop - x = torch.randn((B, self.horizon_steps, self.transition_dim), device=device) + x = torch.randn((B, self.horizon_steps, self.action_dim), device=device) t_all = list(reversed(range(self.denoising_steps))) for i, t in enumerate(t_all): t_b = make_timesteps(B, t, device) diff --git a/model/diffusion/diffusion_dql.py b/model/diffusion/diffusion_dql.py index f6ef0ed..4b53cc2 100644 --- a/model/diffusion/diffusion_dql.py +++ b/model/diffusion/diffusion_dql.py @@ -41,19 +41,19 @@ class DQLDiffusion(DiffusionModel): # Whether to clamp sampled action between [-1, 1] self.clamp_action = clamp_action + # ---------- RL training ----------# + def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma): # get current Q-function - actions_flat = torch.flatten(actions, start_dim=-2) - current_q1, current_q2 = self.critic(obs, actions_flat) + current_q1, current_q2 = self.critic(obs, actions) # get next Q-function next_actions = self.forward( cond=next_obs, deterministic=False, - ) # in DiffusionModel, forward() has no gradient, which is desired here. - next_actions_flat = torch.flatten(next_actions, start_dim=-2) - next_q1, next_q2 = self.critic(next_obs, next_actions_flat) + ) # forward() has no gradient, which is desired here. 
+ next_q1, next_q2 = self.critic(next_obs, next_actions) next_q = torch.min(next_q1, next_q2) # terminal state mask @@ -75,7 +75,7 @@ class DQLDiffusion(DiffusionModel): return loss_critic def loss_actor(self, obs, actions, q1, q2, eta): - bc_loss = self.loss(actions, {0: obs}) + bc_loss = self.loss(actions, obs) if np.random.uniform() > 0.5: q_loss = -q1.mean() / q2.abs().mean().detach() else: @@ -83,6 +83,8 @@ class DQLDiffusion(DiffusionModel): actor_loss = bc_loss + eta * q_loss return actor_loss + # ---------- Sampling ----------#`` + # override @torch.no_grad() def forward( @@ -91,15 +93,10 @@ class DQLDiffusion(DiffusionModel): deterministic=False, ): device = self.betas.device - B = cond.shape[0] - if isinstance(cond, dict): - raise NotImplementedError("Not implemented for images") - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps] + B = len(cond["state"]) # Loop - x = torch.randn((B, self.horizon_steps, self.transition_dim), device=device) + x = torch.randn((B, self.horizon_steps, self.action_dim), device=device) t_all = list(reversed(range(self.denoising_steps))) for i, t in enumerate(t_all): t_b = make_timesteps(B, t, device) @@ -136,15 +133,10 @@ class DQLDiffusion(DiffusionModel): Differentiable forward pass used in actor training. 
""" device = self.betas.device - B = cond.shape[0] - if isinstance(cond, dict): - raise NotImplementedError("Not implemented for images") - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps] + B = len(cond["state"]) # Loop - x = torch.randn((B, self.horizon_steps, self.transition_dim), device=device) + x = torch.randn((B, self.horizon_steps, self.action_dim), device=device) t_all = list(reversed(range(self.denoising_steps))) for i, t in enumerate(t_all): t_b = make_timesteps(B, t, device) diff --git a/model/diffusion/diffusion_idql.py b/model/diffusion/diffusion_idql.py index 121dc26..8a5b917 100644 --- a/model/diffusion/diffusion_idql.py +++ b/model/diffusion/diffusion_idql.py @@ -42,12 +42,13 @@ class IDQLDiffusion(RWRDiffusion): # assign actor self.actor = self.network + # ---------- RL training ----------# + def compute_advantages(self, obs, actions): - # get current Q-function - actions_flat = torch.flatten(actions, start_dim=-2) - with torch.no_grad(): # no gradients for q-function when we update value function - current_q1, current_q2 = self.target_q(obs, actions_flat) + # get current Q-function, stop gradient + with torch.no_grad(): + current_q1, current_q2 = self.target_q(obs, actions) q = torch.min(current_q1, current_q2) # get the current V-function @@ -59,7 +60,6 @@ class IDQLDiffusion(RWRDiffusion): return adv def loss_critic_v(self, obs, actions): - adv = self.compute_advantages(obs, actions) # get the value loss @@ -70,11 +70,10 @@ class IDQLDiffusion(RWRDiffusion): def loss_critic_q(self, obs, next_obs, actions, rewards, dones, gamma): # get current Q-function - actions_flat = torch.flatten(actions, start_dim=-2) - current_q1, current_q2 = self.critic_q(obs, actions_flat) + current_q1, current_q2 = self.critic_q(obs, actions) - # get the next V-function - with torch.no_grad(): # no gradients for value function when we update q function + # get the next V-function, stop gradient + with torch.no_grad(): next_v = self.critic_v(next_obs) # 
terminal state mask @@ -98,8 +97,35 @@ class IDQLDiffusion(RWRDiffusion): def update_target_critic(self, tau): soft_update(self.target_q, self.critic_q, tau) + # override + def p_losses( + self, + x_start, + cond, + t, + ): + device = x_start.device + + # Forward process + noise = torch.randn_like(x_start, device=device) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + + # Predict + x_recon = self.network(x_noisy, t, cond=cond) + + # Loss with mask + if self.predict_epsilon: + loss = F.mse_loss(x_recon, noise, reduction="none") + else: + loss = F.mse_loss(x_recon, x_start, reduction="none") + loss = einops.reduce(loss, "b h d -> b", "mean") + return loss.mean() + + # ---------- Sampling ----------# + + # override @torch.no_grad() - def forward( # override + def forward( self, cond, deterministic=False, @@ -107,23 +133,23 @@ class IDQLDiffusion(RWRDiffusion): critic_hyperparam=0.7, # sampling weight for implicit policy use_expectile_exploration=True, ): + """assume state-only, no rgb in cond""" # repeat obs num_sample times along dim 0 - cond_shape_repeat_dims = tuple(1 for _ in cond.shape) - B, T, D = cond.shape + cond_shape_repeat_dims = tuple(1 for _ in cond["state"].shape) + B, T, D = cond["state"].shape S = num_sample - cond_repeat = cond[None].repeat(num_sample, *cond_shape_repeat_dims) + cond_repeat = cond["state"][None].repeat(num_sample, *cond_shape_repeat_dims) + cond_repeat = cond_repeat.view(-1, T, D) # [B*S, T, D] # for eval, use less noisy samples --- there is still DDPM noise, but final action uses small min_sampling_std samples = super(IDQLDiffusion, self).forward( - cond_repeat, + {"state": cond_repeat}, deterministic=deterministic, ) _, H, A = samples.shape # get current Q-function - actions_flat = torch.flatten(samples, start_dim=-2) - current_q1, current_q2 = self.target_q(cond_repeat, actions_flat) + current_q1, current_q2 = self.target_q({"state": cond_repeat}, samples) q = torch.min(current_q1, current_q2) q = q.view(S, B) @@ 
-141,7 +167,7 @@ class IDQLDiffusion(RWRDiffusion): # Sample as an implicit policy for exploration else: # get the current value function for probabilistic exploration - current_v = self.critic_v(cond_repeat) + current_v = self.critic_v({"state": cond_repeat}) v = current_v.view(S, B) adv = q - v @@ -164,34 +190,3 @@ class IDQLDiffusion(RWRDiffusion): # squeeze dummy dimension samples = samples_best[0] return samples - - # override - def p_losses( - self, - x_start, - obs_cond, - t, - ): - device = x_start.device - B, T, D = x_start.shape - - # handle different ways of passing observation - if isinstance(obs_cond[0], dict): - cond = obs_cond[0] - else: - cond = obs_cond.reshape(B, -1) - - # Forward process - noise = torch.randn_like(x_start, device=device) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - - # Predict - x_recon = self.network(x_noisy, t, cond=cond) - - # Loss with mask - if self.predict_epsilon: - loss = F.mse_loss(x_recon, noise, reduction="none") - else: - loss = F.mse_loss(x_recon, x_start, reduction="none") - loss = einops.reduce(loss, "b h d -> b", "mean") - return loss.mean() diff --git a/model/diffusion/diffusion_ppo.py b/model/diffusion/diffusion_ppo.py index 0e8081f..9c13863 100644 --- a/model/diffusion/diffusion_ppo.py +++ b/model/diffusion/diffusion_ppo.py @@ -1,6 +1,15 @@ """ DPPO: Diffusion Policy Policy Optimization. 
+K: number of denoising steps +To: observation sequence length +Ta: action chunk size +Do: observation dimension +Da: action dimension + +C: image channels +H, W: image height and width + """ from typing import Optional @@ -60,13 +69,15 @@ class PPODiffusion(VPGDiffusion): """ PPO loss - obs: (B, obs_step, obs_dim) - chains: (B, num_denoising_step+1, horizon_step, action_dim) + obs: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + chains: (B, K+1, Ta, Da) returns: (B, ) values: (B, ) advantages: (B,) - oldlogprobs: (B, num_denoising_step, horizon_step, action_dim) - use_bc_loss: add BC regularization loss + oldlogprobs: (B, K, Ta, Da) + use_bc_loss: whether to add BC regularization loss reward_horizon: action horizon that backpropagates gradient """ # Get new logprobs for denoising steps from T-1 to 0 - entropy is fixed fod diffusion diff --git a/model/diffusion/diffusion_ppo_exact.py b/model/diffusion/diffusion_ppo_exact.py index 6d442a8..cd7858c 100644 --- a/model/diffusion/diffusion_ppo_exact.py +++ b/model/diffusion/diffusion_ppo_exact.py @@ -3,6 +3,11 @@ Diffusion policy gradient with exact likelihood estimation. 
Based on score_sde_pytorch https://github.com/yang-song/score_sde_pytorch +To: observation sequence length +Ta: action chunk size +Do: observation dimension +Da: action dimension + """ import torch @@ -52,13 +57,12 @@ class PPOExactDiffusion(PPODiffusion): num_epsilon=sde_num_epsilon, ) - def get_exact_logprobs(self, obs, samples): + def get_exact_logprobs(self, cond, samples): """Use torchdiffeq - samples: B x horizon x transition_dim + samples: (B x Ta x Da) """ # TODO: image input - cond = obs.reshape(-1, self.obs_dim) return self.likelihood_fn( self.actor, self.actor_ft, @@ -79,6 +83,17 @@ class PPOExactDiffusion(PPODiffusion): use_bc_loss=False, **kwargs, ): + """ + PPO loss + + obs: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + samples: (B, Ta, Da) + returns: (B, ) + values: (B, ) + advantages: (B,) + oldlogprobs: (B, ) + """ # Get new logprobs for final x newlogprobs = self.get_exact_logprobs(obs, samples) newlogprobs = newlogprobs.clamp(min=-5, max=2) diff --git a/model/diffusion/diffusion_qsm.py b/model/diffusion/diffusion_qsm.py index 0b3e808..3b833b7 100644 --- a/model/diffusion/diffusion_qsm.py +++ b/model/diffusion/diffusion_qsm.py @@ -39,11 +39,12 @@ class QSMDiffusion(RWRDiffusion): # assign actor self.actor = self.network + # ---------- RL training ----------# + def loss_actor(self, obs, actions, q_grad_coeff): x_start = actions device = x_start.device - B, T, D = x_start.shape - cond = obs.reshape(B, -1) + B = len(x_start) # Forward process noise = torch.randn_like(x_start, device=device) @@ -53,39 +54,35 @@ class QSMDiffusion(RWRDiffusion): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # get current value for noisy actions as the code does --- the algorthm block in the paper is wrong, it says using a_t, the final denoised action - x_noisy_flat = torch.flatten(x_noisy, start_dim=-2) - x_noisy_flat.requires_grad_(True) - current_q1, current_q2 = self.critic_q(obs, x_noisy_flat) + # x_noisy_flat = 
torch.flatten(x_noisy, start_dim=-2) + x_noisy.requires_grad_(True) + current_q1, current_q2 = self.critic_q(obs, x_noisy) # Compute dQ/da|a=noise_actions - gradient_q1 = torch.autograd.grad(current_q1.sum(), x_noisy_flat)[0] - gradient_q2 = torch.autograd.grad(current_q2.sum(), x_noisy_flat)[0] + gradient_q1 = torch.autograd.grad(current_q1.sum(), x_noisy)[0] + gradient_q2 = torch.autograd.grad(current_q2.sum(), x_noisy)[0] gradient_q = torch.stack((gradient_q1, gradient_q2), 0).mean(0).detach() # Predict noise from noisy actions - x_recon = self.network(x_noisy, t, cond=cond) - x_recon = torch.flatten(x_recon, start_dim=-2) + x_recon = self.network(x_noisy, t, cond=obs) # Loss with mask - align predicted noise with critic gradient of noisy actions # Note: the gradient of mu wrt. epsilon has a negative sign loss = F.mse_loss(-x_recon, q_grad_coeff * gradient_q, reduction="none").mean() - return loss def loss_critic(self, obs, next_obs, actions, rewards, dones, gamma): # get current Q-function - actions_flat = torch.flatten(actions, start_dim=-2) - current_q1, current_q2 = self.critic_q(obs, actions_flat) + current_q1, current_q2 = self.critic_q(obs, actions) # get next Q-function - with noise, same as QSM https://github.com/Alescontrela/score_matching_rl/blob/f02a21969b17e322eb229ceb2b0f5a9111b1b968/jaxrl5/agents/score_matching/score_matching_learner.py#L193 next_actions = self.forward( cond=next_obs, deterministic=False, - ) # in DiffusionModel, forward() has no gradient, which is desired here. - next_actions_flat = torch.flatten(next_actions, start_dim=-2) + ) # forward() has no gradient, which is desired here. 
with torch.no_grad(): - next_q1, next_q2 = self.target_q(next_obs, next_actions_flat) + next_q1, next_q2 = self.target_q(next_obs, next_actions) next_q = torch.min(next_q1, next_q2) # terminal state mask diff --git a/model/diffusion/diffusion_rwr.py b/model/diffusion/diffusion_rwr.py index d60cc65..8bad230 100644 --- a/model/diffusion/diffusion_rwr.py +++ b/model/diffusion/diffusion_rwr.py @@ -43,18 +43,11 @@ class RWRDiffusion(DiffusionModel): def p_losses( self, x_start, - obs_cond, + cond, rewards, t, ): device = x_start.device - B, T, D = x_start.shape - - # handle different ways of passing observation - if isinstance(obs_cond[0], dict): - cond = obs_cond[0] - else: - cond = obs_cond.reshape(B, -1) # Forward process noise = torch.randn_like(x_start, device=device) @@ -79,7 +72,7 @@ class RWRDiffusion(DiffusionModel): self, x, t, - cond=None, + cond, ): noise = self.network(x, t, cond=cond) @@ -116,15 +109,10 @@ class RWRDiffusion(DiffusionModel): deterministic=False, ): device = self.betas.device - B = cond.shape[0] - if isinstance(cond, dict): - raise NotImplementedError("Not implemented for images") - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps] + B = len(cond["state"]) # Loop - x = torch.randn((B, self.horizon_steps, self.transition_dim), device=device) + x = torch.randn((B, self.horizon_steps, self.action_dim), device=device) t_all = list(reversed(range(self.denoising_steps))) for i, t in enumerate(t_all): t_b = make_timesteps(B, t, device) diff --git a/model/diffusion/diffusion_vpg.py b/model/diffusion/diffusion_vpg.py index c0b17eb..809148c 100644 --- a/model/diffusion/diffusion_vpg.py +++ b/model/diffusion/diffusion_vpg.py @@ -1,14 +1,20 @@ """ -Policy gradient with diffusion policy. +Policy gradient with diffusion policy. 
VPG: vanilla policy gradient -VPG: vanilla policy gradient +K: number of denoising steps +To: observation sequence length +Ta: action chunk size +Do: observation dimension +Da: action dimension + +C: image channels +H, W: image height and width """ import copy import torch import logging -import einops log = logging.getLogger(__name__) import torch.nn.functional as F @@ -106,7 +112,11 @@ class VPGDiffusion(DiffusionModel): # ---------- Sampling ----------# def step(self): - """Update min_sampling_denoising_std annealing and fine-tuning denoising steps annealing. Both not used currently""" + """ + Anneal min_sampling_denoising_std and fine-tuning denoising steps + + Current configs do not apply annealing + """ # anneal min_sampling_denoising_std if type(self.min_sampling_denoising_std) is not float: self.min_sampling_denoising_std.step() @@ -142,7 +152,7 @@ class VPGDiffusion(DiffusionModel): self, x, t, - cond=None, + cond, index=None, use_base_policy=False, deterministic=False, @@ -160,12 +170,8 @@ class VPGDiffusion(DiffusionModel): # overwrite noise for fine-tuning steps if len(ft_indices) > 0: - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][ft_indices] for key in cond} - else: - cond = cond[ft_indices] - noise_ft = actor(x[ft_indices], t[ft_indices], cond=cond) + cond_ft = {key: cond[key][ft_indices] for key in cond} + noise_ft = actor(x[ft_indices], t[ft_indices], cond=cond_ft) noise[ft_indices] = noise_ft # Predict x_0 @@ -208,7 +214,8 @@ class VPGDiffusion(DiffusionModel): if deterministic: etas = torch.zeros((x.shape[0], 1, 1)).to(x.device) else: - etas = self.eta(cond).unsqueeze(1) # B x 1 x (transition_dim or 1) + # TODO: eta cond + etas = self.eta(cond).unsqueeze(1) # B x 1 x (Da or 1) sigma = ( etas * ((1 - alpha_prev) / (1 - alpha) * (1 - alpha / alpha_prev)) ** 0.5 @@ -242,30 +249,26 @@ class VPGDiffusion(DiffusionModel): Forward pass for sampling actions. 
Args: - cond: (batch_size, obs_step, obs_dim) - deterministic: whether to sample deterministically - return_chain: whether to return the chain of samples - use_base_policy: whether to use the base policy instead + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + deterministic: If true, then std=0 with DDIM, or with DDPM, use normal schedule (instead of clipping at a higher value) + return_chain: whether to return the entire chain of denoised actions + use_base_policy: whether to use the frozen pre-trained policy instead Return: Sample: namedtuple with fields: - trajectories: (batch_size, horizon_steps, transition_dim) - values: (batch_size, ) - chain: (batch_size, denoising_steps + 1, horizon_steps, transition_dim) + trajectories: (B, Ta, Da) + chain: (B, K + 1, Ta, Da) """ device = self.betas.device - if isinstance(cond, dict): - B = cond["state"].shape[0] - cond["state"] = cond["state"][:, : self.cond_steps] - cond["rgb"] = cond["rgb"][:, : self.cond_steps] - else: - B = cond.shape[0] - cond = cond[:, : self.cond_steps] + sample_data = cond["state"] if "state" in cond else cond["rgb"] + B = len(sample_data) # Get updated minimum sampling denoising std min_sampling_denoising_std = self.get_min_sampling_denoising_std() # Loop - x = torch.randn((B, self.horizon_steps, self.transition_dim), device=device) + x = torch.randn((B, self.horizon_steps, self.action_dim), device=device) if self.use_ddim: t_all = self.ddim_t else: @@ -317,17 +320,16 @@ class VPGDiffusion(DiffusionModel): self.ddim_steps - self.ft_denoising_steps - 1 ): chain.append(x) - values = torch.zeros(len(x), device=x.device) if return_chain: chain = torch.stack(chain, dim=1) - return Sample(x, values, chain) + return Sample(x, chain) # ---------- RL training ----------# def get_logprobs( self, - obs, + cond, chains, get_ent: bool = False, use_base_policy: bool = False, @@ -336,35 +338,25 @@ class VPGDiffusion(DiffusionModel): Calculating the 
logprobs of the entire chain of denoised actions. Args: - obs: (B, obs_step, obs_dim) - chains: (B, num_denoising_step+1, horizon_step, action_dim) - get_ent: flag for returning entropy - use_base_policy: flag for using base policy + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + chains: (B, K+1, Ta, Da) + get_ent: flag for returning entropy + use_base_policy: flag for using base policy Returns: - logprobs: (B x num_denoising_steps, horizon_step, action_dim) - entropy (if get_ent=True): (B x num_denoising_steps, horizon_step) + logprobs: (B x K, Ta, Da) + entropy (if get_ent=True): (B x K, Ta) """ - # Repeat obs conditioning for denoising_steps - if isinstance(obs, dict): - obs = { - key: obs[key] - .unsqueeze(1) - .repeat(1, self.ft_denoising_steps, *(1,) * (obs[key].ndim - 1)) - for key in obs - } - else: - obs = einops.repeat(obs, "b h d -> b t h d", t=self.ft_denoising_steps) - - # flatten the first two dimensions - if isinstance(obs, dict): - cond = obs - for key in cond: - cond[key] = einops.rearrange(cond[key], "b t ... -> (b t) ...") - cond[key] = cond[key][:, : self.cond_steps] - else: - cond = einops.rearrange(obs, "b t h d -> (b t) h d") - cond = cond[:, : self.cond_steps] + # Repeat cond for denoising_steps, flatten batch and time dimensions + cond = { + key: cond[key] + .unsqueeze(1) + .repeat(1, self.ft_denoising_steps, *(1,) * (cond[key].ndim - 1)) + .flatten(start_dim=0, end_dim=1) + for key in cond + } # less memory usage than einops? 
# Repeat t for batch dim, keep it 1-dim if self.use_ddim: @@ -393,8 +385,8 @@ class VPGDiffusion(DiffusionModel): chains_next = chains[:, 1:] # Flatten first two dimensions - chains_prev = chains_prev.reshape(-1, self.horizon_steps, self.transition_dim) - chains_next = chains_next.reshape(-1, self.horizon_steps, self.transition_dim) + chains_prev = chains_prev.reshape(-1, self.horizon_steps, self.action_dim) + chains_next = chains_next.reshape(-1, self.horizon_steps, self.action_dim) # Forward pass with previous chains next_mean, logvar, eta = self.p_mean_var( @@ -414,44 +406,35 @@ class VPGDiffusion(DiffusionModel): return log_prob, eta return log_prob - def loss(self, obs, chains, reward): + def loss(self, cond, chains, reward): """ REINFORCE loss. Not used right now. Args: - obs: (n_steps, n_envs, obs_dim) - chains: (n_steps, n_envs, num_denoising_step+1, horizon_step, action_dim) - reward (to go): (n_steps, n_envs) + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + chains: (B, K+1, Ta, Da) + reward (to go): (b,) """ - if torch.is_tensor(reward): - assert not reward.requires_grad - - n_steps, n_envs, _ = obs.shape - - # Flatten first two dimensions - obs = einops.rearrange(obs, "s e d -> (s e) d") - chains = einops.rearrange(chains, "s e t h d -> (s e) t h d") - reward = reward.reshape(-1) - # Get advantage with torch.no_grad(): - value = self.critic(obs).squeeze() + value = self.critic(cond).squeeze() advantage = reward - value # Get logprobs for denoising steps from T-1 to 0 - logprobs, eta = self.get_logprobs(obs, chains, get_ent=True) - # (n_steps x n_envs x denoising_steps) x horizon_steps x (obs_dim+action_dim) + logprobs, eta = self.get_logprobs(cond, chains, get_ent=True) + # (n_steps x n_envs x K) x Ta x (Do+Da) # Ignore obs dimension, and then sum over action dimension logprobs = logprobs[:, :, : self.action_dim].sum(-1) - # -> (n_steps x n_envs x denoising_steps) x horizon_steps + # -> (n_steps x 
n_envs x K) x Ta + + # -> (n_steps x n_envs) x K x Ta + logprobs = logprobs.reshape((-1, self.denoising_steps, self.horizon_steps)) - # -> (n_steps x n_envs) x denoising_steps x horizon_steps - logprobs = logprobs.reshape( - (n_steps * n_envs, self.denoising_steps, self.horizon_steps) - ) # Sum/avg over denoising steps - logprobs = logprobs.mean(-2) # -> (n_steps x n_envs) x horizon_steps + logprobs = logprobs.mean(-2) # -> (n_steps x n_envs) x Ta # Sum/avg over horizon steps logprobs = logprobs.mean(-1) # -> (n_steps x n_envs) @@ -460,6 +443,6 @@ class VPGDiffusion(DiffusionModel): loss_actor = torch.mean(-logprobs * advantage) # Train critic to predict state value - pred = self.critic(obs).squeeze() + pred = self.critic(cond).squeeze() loss_critic = F.mse_loss(pred, reward) return loss_actor, loss_critic, eta diff --git a/model/diffusion/eta.py b/model/diffusion/eta.py index b1068f7..c40a28f 100644 --- a/model/diffusion/eta.py +++ b/model/diffusion/eta.py @@ -28,14 +28,11 @@ class EtaFixed(torch.nn.Module): torch.tensor([2 * (base_eta - min_eta) / (max_eta - min_eta) - 1]) ) - def __call__(self, x): + def __call__(self, cond): """Match input batch size, but do not depend on input""" - if isinstance(x, dict): - B = x["state"].shape[0] - device = x["state"].device - else: - B = x.size(0) - device = x.device + sample_data = cond["state"] if "state" in cond else cond["rgb"] + B = len(sample_data) + device = sample_data.device eta_normalized = torch.tanh(self.eta_logit) # map to min and max from [-1, 1] @@ -64,14 +61,11 @@ class EtaAction(torch.nn.Module): self.min = min_eta self.max = max_eta - def __call__(self, x): + def __call__(self, cond): """Match input batch size, but do not depend on input""" - if isinstance(x, dict): - B = x["state"].shape[0] - device = x["state"].device - else: - B = x.size(0) - device = x.device + sample_data = cond["state"] if "state" in cond else cond["rgb"] + B = len(sample_data) + device = sample_data.device eta_normalized = 
torch.tanh(self.eta_logit) # map to min and max from [-1, 1] @@ -109,13 +103,17 @@ class EtaState(torch.nn.Module): torch.nn.init.xavier_normal_(m.weight, gain=gain) m.bias.data.fill_(0) - def __call__(self, x): - if isinstance(x, dict): + def __call__(self, cond): + if "rgb" in cond: raise NotImplementedError( "State-based eta not implemented for image-based training!" ) - x = x.view(x.size(0), -1) - eta_res = self.mlp_res(x) + # flatten history + B = len(cond["state"]) + state = cond["state"].view(B, -1) + + # forward pass + eta_res = self.mlp_res(state) eta_res = torch.tanh(eta_res) # [-1, 1] eta = eta_res + self.base # [0, 2] return torch.clamp(eta, self.min_res + self.base, self.max_res + self.base) @@ -152,13 +150,17 @@ class EtaStateAction(torch.nn.Module): torch.nn.init.xavier_normal_(m.weight, gain=gain) m.bias.data.fill_(0) - def __call__(self, x): - if isinstance(x, dict): + def __call__(self, cond): + if "rgb" in cond: raise NotImplementedError( "State-action-based eta not implemented for image-based training!" ) - x = x.view(x.size(0), -1) - eta_res = self.mlp_res(x) + # flatten history + B = len(cond["state"]) + state = cond["state"].view(B, -1) + + # forward pass + eta_res = self.mlp_res(state) eta_res = torch.tanh(eta_res) # [-1, 1] eta = eta_res + self.base return torch.clamp(eta, self.min_res + self.base, self.max_res + self.base) diff --git a/model/diffusion/exact_likelihood.py b/model/diffusion/exact_likelihood.py index fdaa721..d435c3a 100644 --- a/model/diffusion/exact_likelihood.py +++ b/model/diffusion/exact_likelihood.py @@ -93,11 +93,12 @@ def get_likelihood_fn( """Compute an unbiased estimate to the log-likelihood in bits/dim. Args: - model: A score model. - data: A PyTorch tensor. 
B x horizon x transition_dim + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + data: (B x Ta x Da) Returns: - logprob: B + logprob: (B,) """ shape = data.shape B, H, A = shape @@ -118,7 +119,9 @@ def get_likelihood_fn( raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") # repeat for expectation - cond_eps = cond.repeat_interleave(num_epsilon, dim=0) + cond_eps = { + key: cond[key].repeat_interleave(num_epsilon, dim=0) for key in cond + } def ode_func(t, x): x = x[:, :-1] @@ -132,7 +135,7 @@ def get_likelihood_fn( model_fn = model_ft else: model_fn = model - x = x.view(shape) # B x horizon x transition_dim + x = x.view(shape) # B x horizon x action_dim drift = drift_fn( model_fn, x, diff --git a/model/diffusion/mlp_diffusion.py b/model/diffusion/mlp_diffusion.py index aec8dec..5b6174b 100644 --- a/model/diffusion/mlp_diffusion.py +++ b/model/diffusion/mlp_diffusion.py @@ -25,6 +25,7 @@ class VisionDiffusionMLP(nn.Module): transition_dim, horizon_steps, cond_dim, + img_cond_steps=1, time_dim=16, mlp_dims=[256, 256], activation_type="Mish", @@ -46,6 +47,8 @@ class VisionDiffusionMLP(nn.Module): if augment: self.aug = RandomShiftsAug(pad=4) self.augment = augment + self.num_img = num_img + self.img_cond_steps = img_cond_steps if spatial_emb > 0: assert spatial_emb > 1, "this is the dimension" if num_img > 1: @@ -101,36 +104,44 @@ class VisionDiffusionMLP(nn.Module): self, x, time, - cond=None, + cond: dict, **kwargs, ): """ - x: (B,T,obs_dim) + x: (B, Ta, Da) time: (B,) or int, diffusion step - cond: dict (B,cond_step,cond_dim) - output: (B,T,input_dim) + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + + TODO long term: more flexible handling of cond """ - # flatten T and input_dim - B, T, input_dim = x.shape + B, Ta, Da = x.shape + _, T_rgb, C, H, W = cond["rgb"].shape + + # flatten chunk x = x.view(B, -1) - # flatten cond_dim if exists - if cond["rgb"].ndim 
== 5: - rgb = einops.rearrange(cond["rgb"], "b d c h w -> (b d) c h w") + # flatten history + state = cond["state"].view(B, -1) + + # Take recent images --- sometimes we want to use fewer img_cond_steps than cond_steps (e.g., 1 image but 3 prio) + rgb = cond["rgb"][:, -self.img_cond_steps :] + + # concatenate images in cond by channels + if self.num_img > 1: + rgb = rgb.reshape(B, T_rgb, self.num_img, 3, H, W) + rgb = einops.rearrange(rgb, "b t n c h w -> b n (t c) h w") else: - rgb = cond["rgb"] - if cond["state"].ndim == 3: - state = einops.rearrange(cond["state"], "b d c -> (b d) c") - else: - state = cond["state"] + rgb = einops.rearrange(rgb, "b t c h w -> b (t c) h w") # convert rgb to float32 for augmentation rgb = rgb.float() # get vit output - pass in two images separately - if rgb.shape[1] == 6: # TODO: properly handle multiple images - rgb1 = rgb[:, :3] - rgb2 = rgb[:, 3:] + if self.num_img > 1: # TODO: properly handle multiple images + rgb1 = rgb[:, 0] + rgb2 = rgb[:, 1] if self.augment: rgb1 = self.aug(rgb1) rgb2 = self.aug(rgb2) @@ -141,7 +152,7 @@ class VisionDiffusionMLP(nn.Module): feat = torch.cat([feat1, feat2], dim=-1) else: # single image if self.augment: - rgb = self.aug(rgb) # uint8 -> float32 + rgb = self.aug(rgb) feat = self.backbone(rgb) # compress @@ -159,7 +170,7 @@ class VisionDiffusionMLP(nn.Module): # mlp out = self.mlp_mean(x) - return out.view(B, T, input_dim) + return out.view(B, Ta, Da) class DiffusionMLP(nn.Module): @@ -210,27 +221,32 @@ class DiffusionMLP(nn.Module): self, x, time, - cond=None, + cond, **kwargs, ): """ - x: (B,T,obs_dim) + x: (B, Ta, Da) time: (B,) or int, diffusion step - cond: (B,cond_step,cond_dim) - output: (B,T,input_dim) + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) """ - # flatten T and input_dim - B, T, input_dim = x.shape + B, Ta, Da = x.shape + + # flatten chunk x = x.view(B, -1) - cond = cond.view(B, -1) if cond is not None else None + + # flatten history + state = 
cond["state"].view(B, -1) + + # obs encoder if hasattr(self, "cond_mlp"): - cond = self.cond_mlp(cond) + state = self.cond_mlp(state) # append time and cond time = time.view(B, 1) time_emb = self.time_embedding(time).view(B, self.time_dim) - x = torch.cat([x, time_emb, cond], dim=-1) + x = torch.cat([x, time_emb, state], dim=-1) - # mlp + # mlp head out = self.mlp_mean(x) - return out.view(B, T, input_dim) + return out.view(B, Ta, Da) diff --git a/model/diffusion/sampling.py b/model/diffusion/sampling.py index 4b8e279..1ad011c 100644 --- a/model/diffusion/sampling.py +++ b/model/diffusion/sampling.py @@ -26,12 +26,6 @@ def extract(a, t, x_shape): return out.reshape(b, *((1,) * (len(x_shape) - 1))) -def apply_obs_conditioning(x, conditions, action_dim): - for t, val in conditions.items(): - x[:, t, action_dim:] = val.clone() - return x - - def make_timesteps(batch_size, i, device): t = torch.full((batch_size,), i, device=device, dtype=torch.long) return t diff --git a/model/diffusion/unet.py b/model/diffusion/unet.py index 621b507..c5e6836 100644 --- a/model/diffusion/unet.py +++ b/model/diffusion/unet.py @@ -270,15 +270,22 @@ class Unet1D(nn.Module): **kwargs, ): """ - x: (B,T,input_dim) + x: (B, Ta, act_dim) time: (B,) or int, diffusion step - cond: (B,obs_step,cond_dim) - output: (B,T,input_dim) + cond: dict with key state/rgb; more recent obs at the end + state: (B, To, obs_dim) """ + B = len(x) + + # move chunk dim to the end x = einops.rearrange(x, "b h t -> b t h") - cond = cond.view(cond.shape[0], -1) + + # flatten history + state = cond["state"].view(B, -1) + + # obs encoder if hasattr(self, "cond_mlp"): - cond = self.cond_mlp(cond) + state = self.cond_mlp(state) # 1. 
time if not torch.is_tensor(time): @@ -288,7 +295,7 @@ class Unet1D(nn.Module): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML time = time.expand(x.shape[0]) global_feature = self.time_mlp(time) - global_feature = torch.cat([global_feature, cond], axis=-1) + global_feature = torch.cat([global_feature, state], axis=-1) # encode local features h_local = list() diff --git a/model/rl/gaussian_ppo.py b/model/rl/gaussian_ppo.py index 5d3cdeb..a7e3be8 100644 --- a/model/rl/gaussian_ppo.py +++ b/model/rl/gaussian_ppo.py @@ -1,6 +1,14 @@ """ PPO for Gaussian policy. +To: observation sequence length +Ta: action chunk size +Do: observation dimension +Da: action dimension + +C: image channels +H, W: image height and width + """ from typing import Optional @@ -41,8 +49,10 @@ class PPO_Gaussian(VPG_Gaussian): """ PPO loss - obs: (B, obs_step, obs_dim) - actions: (B, horizon_step, action_dim) + obs: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + actions: (B, Ta, Da) returns: (B, ) values: (B, ) advantages: (B,) diff --git a/model/rl/gaussian_rwr.py b/model/rl/gaussian_rwr.py index fd37d83..2d0a1ad 100644 --- a/model/rl/gaussian_rwr.py +++ b/model/rl/gaussian_rwr.py @@ -29,9 +29,8 @@ class RWR_Gaussian(GaussianModel): # override def loss(self, actions, obs, reward_weights): - cond = obs - B = cond.shape[0] - means, scales = self.network(cond) + B = len(obs) + means, scales = self.network(obs) dist = D.Normal(loc=means, scale=scales) log_prob = dist.log_prob(actions.view(B, -1)).mean(-1) @@ -42,16 +41,8 @@ class RWR_Gaussian(GaussianModel): # override @torch.no_grad() def forward(self, cond, deterministic=False, **kwargs): - """ - Args: - cond: (batch_size, horizon, obs_dim) - - Return: - actions: (batch_size, horizon_steps, transition_dim) - """ - B = cond.shape[0] actions = super().forward( - cond=cond.view(B, -1), + cond=cond, deterministic=deterministic, randn_clip_value=self.randn_clip_value, ) 
diff --git a/model/rl/gaussian_vpg.py b/model/rl/gaussian_vpg.py index aecb71d..34b6d98 100644 --- a/model/rl/gaussian_vpg.py +++ b/model/rl/gaussian_vpg.py @@ -15,26 +15,14 @@ class VPG_Gaussian(GaussianModel): self, actor, critic, - cond_steps=1, randn_clip_value=10, - network_path=None, **kwargs, ): super().__init__(network=actor, **kwargs) - self.cond_steps = cond_steps self.randn_clip_value = randn_clip_value # Value function for obs - simple MLP self.critic = critic.to(self.device) - if network_path is not None: - checkpoint = torch.load( - network_path, map_location=self.device, weights_only=True - ) - self.load_state_dict( - checkpoint["model"], - strict=False, - ) - logging.info("Loaded actor from %s", network_path) # Re-name network to actor self.actor_ft = actor @@ -44,15 +32,31 @@ class VPG_Gaussian(GaussianModel): for param in self.actor.parameters(): param.requires_grad = False + # ---------- Sampling ----------# + + @torch.no_grad() + def forward( + self, + cond, + deterministic=False, + use_base_policy=False, + ): + return super().forward( + cond=cond, + deterministic=deterministic, + randn_clip_value=self.randn_clip_value, + network_override=self.actor if use_base_policy else None, + ) + + # ---------- RL training ----------# + def get_logprobs( self, cond, actions, use_base_policy=False, ): - B, T, D = actions.shape - if not isinstance(cond, dict): - cond = cond.view(B, -1) + B = len(actions) dist = self.forward_train( cond, deterministic=False, @@ -66,22 +70,3 @@ class VPG_Gaussian(GaussianModel): def loss(self, obs, actions, reward): raise NotImplementedError - - @torch.no_grad() - def forward( - self, - cond, - deterministic=False, - use_base_policy=False, - ): - if isinstance(cond, dict): - B = cond["state"].shape[0] - else: - B = cond.shape[0] - cond = cond.view(B, -1) - return super().forward( - cond=cond, - deterministic=deterministic, - randn_clip_value=self.randn_clip_value, - network_override=self.actor if use_base_policy else None, - ) 
diff --git a/model/rl/gmm_ppo.py b/model/rl/gmm_ppo.py index a92276a..d732972 100644 --- a/model/rl/gmm_ppo.py +++ b/model/rl/gmm_ppo.py @@ -1,6 +1,14 @@ """ PPO for GMM policy. +To: observation sequence length +Ta: action chunk size +Do: observation dimension +Da: action dimension + +C: image channels +H, W: image height and width + """ from typing import Optional @@ -41,8 +49,10 @@ class PPO_GMM(VPG_GMM): """ PPO loss - obs: (B, obs_step, obs_dim) - actions: (B, horizon_step, action_dim) + obs: dict with key state/rgb; more recent obs at the end + state: (B, To, Do) + rgb: (B, To, C, H, W) + actions: (B, Ta, Da) returns: (B, ) values: (B, ) advantages: (B,) diff --git a/model/rl/gmm_vpg.py b/model/rl/gmm_vpg.py index fc70b79..8f095c2 100644 --- a/model/rl/gmm_vpg.py +++ b/model/rl/gmm_vpg.py @@ -8,36 +8,35 @@ class VPG_GMM(GMMModel): self, actor, critic, - cond_steps=1, - network_path=None, **kwargs, ): super().__init__(network=actor, **kwargs) - self.cond_steps = cond_steps # Re-name network to actor self.actor_ft = actor # Value function for obs - simple MLP self.critic = critic.to(self.device) - if network_path is not None: - checkpoint = torch.load( - network_path, map_location=self.device, weights_only=True - ) - self.load_state_dict( - checkpoint["model"], - strict=False, - ) - logging.info("Loaded actor from %s", network_path) + + # ---------- Sampling ----------# + + @torch.no_grad() + def forward(self, cond, deterministic=False): + return super().forward( + cond=cond, + deterministic=deterministic, + ) + + # ---------- RL training ----------# def get_logprobs( self, cond, actions, ): - B, T, D = actions.shape + B = len(actions) dist, entropy, std = self.forward_train( - cond.view(B, -1), + cond, deterministic=False, ) log_prob = dist.log_prob(actions.view(B, -1)) @@ -45,12 +44,3 @@ class VPG_GMM(GMMModel): def loss(self, obs, chains, reward): raise NotImplementedError - - # override to diffuse over action only - @torch.no_grad() - def forward(self, cond, 
deterministic=False): - B = cond.shape[0] - return super().forward( - cond=cond.view(B, -1), - deterministic=deterministic, - ) diff --git a/script/dataset/filter_d3il_avoid_data.py b/script/dataset/filter_d3il_avoid_data.py index 76a2a16..c463a45 100644 --- a/script/dataset/filter_d3il_avoid_data.py +++ b/script/dataset/filter_d3il_avoid_data.py @@ -245,7 +245,7 @@ def make_dataset( plot(out_val, name="val-trajs.png") # Save to np file - save_train_path = os.path.join(save_dir, save_name_prefix + "train.pkl") + save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") with open(save_train_path, "wb") as f: pickle.dump(out_train, f) diff --git a/script/dataset/get_d4rl_dataset.py b/script/dataset/get_d4rl_dataset.py index 83fc967..24cf556 100644 --- a/script/dataset/get_d4rl_dataset.py +++ b/script/dataset/get_d4rl_dataset.py @@ -142,7 +142,7 @@ def make_dataset(env_name, save_dir, save_name_prefix, val_split, logger): prev_index = cur_index + 1 # Save to np file - save_train_path = os.path.join(save_dir, save_name_prefix + "train.pkl") + save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") with open(save_train_path, "wb") as f: pickle.dump(out_train, f) diff --git a/script/dataset/process_d3il_dataset.py b/script/dataset/process_d3il_dataset.py index bf2dac3..bd2e662 100644 --- a/script/dataset/process_d3il_dataset.py +++ b/script/dataset/process_d3il_dataset.py @@ -192,7 +192,7 @@ def make_dataset(load_path, save_dir, save_name_prefix, env_type, val_split): plot(out_val, name="val-trajs.png") # Save to np file - save_train_path = os.path.join(save_dir, save_name_prefix + "train.pkl") + save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") with open(save_train_path, "wb") as f: pickle.dump(out_train, f) diff 
--git a/script/dataset/process_robomimic_dataset.py b/script/dataset/process_robomimic_dataset.py index c0ccc5c..5cbc4b4 100644 --- a/script/dataset/process_robomimic_dataset.py +++ b/script/dataset/process_robomimic_dataset.py @@ -303,7 +303,7 @@ def make_dataset( val_episode_reward_all.append(np.sum(data_traj["rewards"])) # Save to np file - save_train_path = os.path.join(save_dir, save_name_prefix + "train.pkl") + save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz") save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl") with open(save_train_path, "wb") as f: pickle.dump(out_train, f)