simplify pre-training dataset, use npz

This commit is contained in:
allenzren 2024-09-08 17:52:16 -04:00
parent 447c8dfd02
commit 8ce0aa1485
66 changed files with 170 additions and 324 deletions

View File

@ -57,7 +57,7 @@ export DPPO_DATA_DIR=/path/to/data -->
<!-- ``` --> <!-- ``` -->
Pre-training data for all tasks are pre-processed and can be found at [here](https://drive.google.com/drive/folders/1AXZvNQEKOrp0_jk1VLepKh_oHCg_9e3r?usp=drive_link). Pre-training script will download the data (including normalization statistics) automatically to the data directory. Pre-training data for all tasks are pre-processed and can be found at [here](https://drive.google.com/drive/folders/1AXZvNQEKOrp0_jk1VLepKh_oHCg_9e3r?usp=drive_link). Pre-training script will download the data (including normalization statistics) automatically to the data directory.
<!-- The data path follows `${DPPO_DATA_DIR}/<benchmark>/<task>/train.npz`, e.g., `${DPPO_DATA_DIR}/gym/hopper-medium-v2/train.pkl`. --> <!-- The data path follows `${DPPO_DATA_DIR}/<benchmark>/<task>/train.npz`, e.g., `${DPPO_DATA_DIR}/gym/hopper-medium-v2/train.npz`. -->
### Run pre-training with data ### Run pre-training with data
All the configs can be found under `cfg/<env>/pretrain/`. A new WandB project may be created based on `wandb.project` in the config file; set `wandb=null` in the command line to test without WandB logging. All the configs can be found under `cfg/<env>/pretrain/`. A new WandB project may be created based on `wandb.project` in the config file; set `wandb=null` in the command line to test without WandB logging.
@ -159,7 +159,8 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
## Adding your own dataset/environment ## Adding your own dataset/environment
### Pre-training data ### Pre-training data
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a pickle file containing a dictionary of `observations`, `actions`, and `traj_length`, where `observations` and `actions` have the shape of num_episode x max_episode_length x obs_dim/act_dim, and `traj_length` is a 1-D array. One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes an npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_lengths`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` has the shape of num_total_steps x C (concatenated if multiple images) x H x W, and `traj_lengths` is a 1-D array of per-episode lengths used for indexing across num_total_steps.
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
**Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131).

View File

@ -1,110 +0,0 @@
"""
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/buffer.py
"""
import numpy as np
import torch
def atleast_2d(x):
    """
    Return `x` with trailing singleton axes appended until it has at least two dimensions.

    Args:
        x: a torch tensor, a numpy array, or any numpy array-like (list, tuple,
            Python/numpy scalar).

    Returns:
        A torch tensor if `x` was a torch tensor, otherwise a numpy array.
        Inputs that already have two or more dimensions are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        while x.dim() < 2:
            x = x.unsqueeze(-1)
        return x

    # asanyarray hands ndarrays (and subclasses) back untouched, while lists and
    # plain scalars, which have no `.ndim`, are converted instead of raising
    # AttributeError in the loop below.
    x = np.asanyarray(x)
    while x.ndim < 2:
        x = np.expand_dims(x, axis=-1)
    return x
class StitchedBuffer:
    """
    Flat storage for trajectories written back to back along the first axis.

    Every field passed to `add_path` gets one pre-allocated float32 buffer of
    length `sum_of_path_lengths`. The extra field `path_lengths` stores, for
    each step, the length of the trajectory that step belongs to. Buffers are
    numpy arrays when `device == "cpu"` and torch tensors moved to `device`
    otherwise.
    """

    def __init__(
        self,
        sum_of_path_lengths,
        device="cpu",
    ):
        """
        Args:
            sum_of_path_lengths: total number of steps the buffer can hold.
            device: "cpu" for numpy storage, anything else for torch storage on
                that device.
        """
        self.sum_of_path_lengths = sum_of_path_lengths
        self.device = device
        if device == "cpu":
            path_lengths = np.zeros(sum_of_path_lengths, dtype=int)
        else:
            path_lengths = torch.zeros(sum_of_path_lengths, dtype=int).to(device)
        self._dict = {"path_lengths": path_lengths}
        # `_count` is the write offset in steps; episodes are counted separately
        # so that `n_episodes` and `n_steps` report different quantities.
        self._count = 0
        self._n_episodes = 0

    def __repr__(self):
        return "Fields:\n" + "\n".join(
            f"    {key}: {val.shape}" for key, val in self.items()
        )

    def __getitem__(self, key):
        return self._dict[key]

    def __setitem__(self, key, val):
        self._dict[key] = val
        self._add_attributes()

    @property
    def n_episodes(self):
        """Number of trajectories added through `add_path`."""
        return self._n_episodes

    @property
    def n_steps(self):
        """Number of steps written so far, i.e. the sum of the added trajectories' lengths."""
        return self._count

    def _add_keys(self, path):
        """Remember the field names of the first path; later calls leave them unchanged."""
        if hasattr(self, "keys"):
            return
        self.keys = list(path.keys())

    def _add_attributes(self):
        """
        can access fields with `buffer.observations`
        instead of `buffer['observations']`
        """
        for key, val in self._dict.items():
            setattr(self, key, val)

    def items(self):
        """Return (name, buffer) pairs for every field except `path_lengths`."""
        return {k: v for k, v in self._dict.items() if k != "path_lengths"}.items()

    def _allocate(self, key, array):
        """Create the zero-filled float32 buffer for `key`, sized from `array`'s trailing dims."""
        assert key not in self._dict
        dim = array.shape[1:]  # skip batch dimension
        shape = (self.sum_of_path_lengths, *dim)
        # Storage is always float32, so values of any other dtype are cast on assignment.
        if self.device == "cpu":
            self._dict[key] = np.zeros(shape, dtype=np.float32)
        else:
            self._dict[key] = torch.zeros(shape, dtype=torch.float32).to(self.device)

    def add_path(self, path):
        """
        Append one trajectory at the current write offset.

        Args:
            path: dict mapping field name to an array/tensor whose first axis is
                time. Must contain "observations", which defines the path length.
                The field names of the first path added are the ones stored for
                every later path.

        Raises:
            ValueError: if the path does not fit in the remaining capacity.
        """
        path_length = len(path["observations"])
        end = self._count + path_length
        # Fail early with a clear message, before anything is written, instead
        # of a shape-mismatch error from the slice assignment below.
        if end > self.sum_of_path_lengths:
            raise ValueError(
                f"Path of length {path_length} does not fit: {self._count} of "
                f"{self.sum_of_path_lengths} steps already used"
            )

        ## if first path added, set keys based on contents
        self._add_keys(path)

        ## add tracked keys in path
        for key in self.keys:
            array = atleast_2d(path[key])
            if key not in self._dict:
                self._allocate(key, array)
            self._dict[key][self._count : end] = array

        ## record path length for every step of this path
        self._dict["path_lengths"][self._count : end] = path_length

        ## advance the step offset and the episode counter
        self._count = end
        self._n_episodes += 1

    def finalize(self):
        """Expose every stored field as an attribute of the buffer."""
        self._add_attributes()

View File

@ -8,7 +8,6 @@ No normalization is applied here --- we always normalize the data when pre-proce
""" """
from collections import namedtuple from collections import namedtuple
from tqdm import tqdm
import numpy as np import numpy as np
import torch import torch
import logging import logging
@ -17,11 +16,7 @@ import random
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from .buffer import StitchedBuffer
Batch = namedtuple("Batch", "trajectories conditions") Batch = namedtuple("Batch", "trajectories conditions")
ValueBatch = namedtuple("ValueBatch", "trajectories conditions values")
class StitchedSequenceDataset(torch.utils.data.Dataset): class StitchedSequenceDataset(torch.utils.data.Dataset):
@ -32,7 +27,7 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
(tuple of) dimension of observation, action, images, etc. (tuple of) dimension of observation, action, images, etc.
Example: Example:
Observations: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------] states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------] Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
""" """
@ -43,67 +38,56 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
cond_steps=1, cond_steps=1,
max_n_episodes=10000, max_n_episodes=10000,
use_img=False, use_img=False,
device="cpu", device="cuda:0",
): ):
self.horizon_steps = horizon_steps self.horizon_steps = horizon_steps
self.cond_steps = cond_steps self.cond_steps = cond_steps
self.device = device self.device = device
self.use_img = use_img
# Load dataset to device specified # Load dataset to device specified
if dataset_path.endswith(".npz"): if dataset_path.endswith(".npz"):
dataset = np.load(dataset_path, allow_pickle=True) dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
else: else:
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset = pickle.load(f) dataset = pickle.load(f)
num_episodes = dataset["observations"].shape[0] traj_lengths = dataset["traj_lengths"] # 1-D array
total_num_steps = np.sum(traj_lengths[:max_n_episodes])
# Get the sum total of the valid trajectories' lengths
traj_lengths = dataset["traj_length"]
sum_of_path_lengths = np.sum(traj_lengths)
self.sum_of_path_lengths = sum_of_path_lengths
fields = StitchedBuffer(sum_of_path_lengths, device)
for i in tqdm(
range(min(max_n_episodes, num_episodes)), desc="Loading trajectories"
):
traj_length = traj_lengths[i]
episode = {
"observations": dataset["observations"][i][:traj_length],
"actions": dataset["actions"][i][:traj_length],
"episode_ids": i * np.ones(traj_length),
}
if use_img:
episode["images"] = dataset["images"][i][:traj_length]
for key, val in episode.items():
if device == "cpu":
episode[key] = val
else:
# if None array, save as empty tensor
if np.all(np.equal(episode[key], None)):
episode[key] = torch.empty(episode[key].shape).to(device)
else:
if key == "images":
episode[key] = torch.tensor(val, dtype=torch.uint8).to(
device
)
# (, H, W, C) -> (, C, H, W)
episode[key] = episode[key].permute(0, 3, 1, 2)
else:
episode[key] = torch.tensor(val, dtype=torch.float32).to(
device
)
fields.add_path(episode)
fields.finalize()
# Set up indices for sampling
self.indices = self.make_indices(traj_lengths, horizon_steps) self.indices = self.make_indices(traj_lengths, horizon_steps)
self.obs_dim = fields.observations.shape[-1]
self.action_dim = fields.actions.shape[-1] # Extract states and actions up to max_n_episodes
self.fields = fields self.states = (
self.n_episodes = fields.n_episodes torch.from_numpy(dataset["states"][:total_num_steps]).float().to(device)
self.path_lengths = fields.path_lengths ) # (total_num_steps, obs_dim)
self.traj_lengths = traj_lengths self.actions = (
self.use_img = use_img torch.from_numpy(dataset["actions"][:total_num_steps]).float().to(device)
log.info(fields) ) # (total_num_steps, action_dim)
log.info(f"Loaded dataset from {dataset_path}")
log.info(f"Number of episodes: {min(max_n_episodes, len(traj_lengths))}")
log.info(f"States shape/type: {self.states.shape, self.states.dtype}")
log.info(f"Actions shape/type: {self.actions.shape, self.actions.dtype}")
if self.use_img:
self.images = torch.from_numpy(dataset["images"][:total_num_steps]).to(
device
) # (total_num_steps, C, H, W)
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
def __getitem__(self, idx):
start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[start:end]
actions = self.actions[start:end]
if self.use_img:
images = self.images[start:end]
conditions = {
1 - self.cond_steps: {"state": states[0], "rgb": images[0]}
} # TODO: allow obs history, -1, -2, ...
else:
conditions = {1 - self.cond_steps: states[0]}
batch = Batch(actions, conditions)
return batch
def make_indices(self, traj_lengths, horizon_steps): def make_indices(self, traj_lengths, horizon_steps):
""" """
@ -119,44 +103,12 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
return indices return indices
def set_train_val_split(self, train_split): def set_train_val_split(self, train_split):
"""Not doing validation right now"""
num_train = int(len(self.indices) * train_split) num_train = int(len(self.indices) * train_split)
train_indices = random.sample(self.indices, num_train) train_indices = random.sample(self.indices, num_train)
val_indices = [i for i in range(len(self.indices)) if i not in train_indices] val_indices = [i for i in range(len(self.indices)) if i not in train_indices]
self.indices = train_indices self.indices = train_indices
return val_indices return val_indices
def set_indices(self, indices):
self.indices = indices
def get_conditions(self, observations, images=None):
"""
condition on current observation for planning. Take into account the number of conditioning steps.
"""
if images is not None:
return {
1 - self.cond_steps: {"state": observations[0], "rgb": images[0]}
} # TODO: allow obs history, -1, -2, ...
else:
return {1 - self.cond_steps: observations[0]}
def __len__(self): def __len__(self):
return len(self.indices) return len(self.indices)
def __getitem__(self, idx, eps=1e-4):
raise NotImplementedError("Get item defined in subclass.")
class StitchedActionSequenceDataset(StitchedSequenceDataset):
"""Only use action trajectory, and then obs_cond for current observation"""
def __getitem__(self, idx):
start = self.indices[idx]
end = start + self.horizon_steps
observations = self.fields.observations[start:end]
actions = self.fields.actions[start:end]
images = None
if self.use_img:
images = self.fields.images[start:end]
conditions = self.get_conditions(observations, images)
batch = Batch(actions, conditions)
return batch

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: avoid_m1_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: avoid_m1_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -63,7 +63,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m1_pre_gaussian_mlp_ta${horizon_steps} name: avoid_m1_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m1_pre_gmm_mlp_ta${horizon_steps} name: avoid_m1_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m1/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: avoid_m2_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: avoid_m2_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -63,7 +63,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m2_pre_gaussian_mlp_ta${horizon_steps} name: avoid_m2_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m2_pre_gmm_mlp_ta${horizon_steps} name: avoid_m2_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m2/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: avoid_m3_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: avoid_m3_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -63,7 +63,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m3_pre_gaussian_mlp_ta${horizon_steps} name: avoid_m3_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: avoid_m3_pre_gmm_mlp_ta${horizon_steps} name: avoid_m3_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/d3il-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/d3il/avoid_m3/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -67,7 +67,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -66,7 +66,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -67,7 +67,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -66,7 +66,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -66,7 +66,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -65,7 +65,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -66,7 +66,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/furniture/${task}_${randomness}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/gym/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -61,7 +61,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -82,7 +82,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -53,7 +53,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -74,7 +74,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_transformer_ta${horizon_steps} name: ${env}_pre_gaussian_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_mlp_ta${horizon_steps} name: ${env}_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_transformer_ta${horizon_steps} name: ${env}_pre_gmm_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -61,7 +61,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -82,7 +82,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -53,7 +53,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -74,7 +74,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_transformer_ta${horizon_steps} name: ${env}_pre_gaussian_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_mlp_ta${horizon_steps} name: ${env}_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_transformer_ta${horizon_steps} name: ${env}_pre_gmm_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -62,7 +62,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -82,7 +82,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -53,7 +53,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -74,7 +74,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_transformer_ta${horizon_steps} name: ${env}_pre_gaussian_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_mlp_ta${horizon_steps} name: ${env}_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_transformer_ta${horizon_steps} name: ${env}_pre_gmm_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -61,7 +61,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -83,7 +83,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_diffusion_agent.TrainDiffusionAgent
name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps} name: ${env}_pre_diffusion_unet_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -64,7 +64,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -53,7 +53,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps} name: ${env}_pre_gaussian_mlp_img_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}-img/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -75,7 +75,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
use_img: True use_img: True
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gaussian_transformer_ta${horizon_steps} name: ${env}_pre_gaussian_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_mlp_ta${horizon_steps} name: ${env}_pre_gmm_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -55,7 +55,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -7,7 +7,7 @@ _target_: agent.pretrain.train_gaussian_agent.TrainGaussianAgent
name: ${env}_pre_gmm_transformer_ta${horizon_steps} name: ${env}_pre_gmm_transformer_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed} logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.pkl train_dataset_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env}/train.npz
seed: 42 seed: 42
device: cuda:0 device: cuda:0
@ -57,7 +57,7 @@ ema:
decay: 0.995 decay: 0.995
train_dataset: train_dataset:
_target_: agent.dataset.sequence.StitchedActionSequenceDataset _target_: agent.dataset.sequence.StitchedSequenceDataset
dataset_path: ${train_dataset_path} dataset_path: ${train_dataset_path}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps} cond_steps: ${cond_steps}

View File

@ -124,6 +124,9 @@ class VisionDiffusionMLP(nn.Module):
else: else:
state = cond["state"] state = cond["state"]
# convert rgb to float32 for augmentation
rgb = rgb.float()
# get vit output - pass in two images separately # get vit output - pass in two images separately
if rgb.shape[1] == 6: # TODO: properly handle multiple images if rgb.shape[1] == 6: # TODO: properly handle multiple images
rgb1 = rgb[:, :3] rgb1 = rgb[:, :3]