squash commits

This commit is contained in:
allenzren 2024-09-11 21:09:17 -04:00
parent 8ce0aa1485
commit 2ddf63b8f5
200 changed files with 1240 additions and 1186 deletions

View File

@ -16,19 +16,21 @@ import random
log = logging.getLogger(__name__)
Batch = namedtuple("Batch", "trajectories conditions")
Batch = namedtuple("Batch", "actions conditions")
class StitchedSequenceDataset(torch.utils.data.Dataset):
"""
Dataset to efficiently load and sample trajectories. Stitches episodes together in the time dimension to avoid excessive zero padding. Episode ID's are used to index unique trajectories.
Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file.
Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is traj length of episode e and D is
(tuple of) dimension of observation, action, images, etc.
Use the first max_n_episodes episodes (instead of random sampling)
Example:
states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
Each sample is a namedtuple of (1) chunked actions and (2) a dictionary of conditions with key `state` (and `rgb` when use_img is set), each stacked over its conditioning timesteps.
"""
def __init__(
@ -36,23 +38,30 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
dataset_path,
horizon_steps=64,
cond_steps=1,
img_cond_steps=1,
max_n_episodes=10000,
use_img=False,
device="cuda:0",
):
assert (
img_cond_steps <= cond_steps
), "consider using more cond_steps than img_cond_steps"
self.horizon_steps = horizon_steps
self.cond_steps = cond_steps
self.cond_steps = cond_steps # states (proprio, etc.)
self.img_cond_steps = img_cond_steps
self.device = device
self.use_img = use_img
# Load dataset to device specified
if dataset_path.endswith(".npz"):
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
else:
elif dataset_path.endswith(".pkl"):
with open(dataset_path, "rb") as f:
dataset = pickle.load(f)
traj_lengths = dataset["traj_lengths"] # 1-D array
total_num_steps = np.sum(traj_lengths[:max_n_episodes])
else:
raise ValueError(f"Unsupported file format: {dataset_path}")
traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array
total_num_steps = np.sum(traj_lengths)
# Set up indices for sampling
self.indices = self.make_indices(traj_lengths, horizon_steps)
@ -75,35 +84,51 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
def __getitem__(self, idx):
start = self.indices[idx]
"""
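Repeat the first state/image of the episode when the observation history reaches back before the episode start.
NOTE(review): the indexing below uses min(num_before_start - t, 0), which yields index 0 or a negative (wrap-around) index into the slice; max(num_before_start - t, 0) appears to be what is intended — confirm.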
"""
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps
states = self.states[start:end]
states = self.states[(start - num_before_start) : end]
actions = self.actions[start:end]
states = torch.stack(
[
states[min(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end
conditions = {"state": states}
if self.use_img:
images = self.images[start:end]
conditions = {
1 - self.cond_steps: {"state": states[0], "rgb": images[0]}
} # TODO: allow obs history, -1, -2, ...
else:
conditions = {1 - self.cond_steps: states[0]}
images = self.images[(start - num_before_start) : end]
images = torch.stack(
[
images[min(num_before_start - t, 0)]
for t in reversed(range(self.img_cond_steps))
]
)
conditions["rgb"] = images
batch = Batch(actions, conditions)
return batch
def make_indices(self, traj_lengths, horizon_steps):
"""
makes indices for sampling from dataset;
each index maps to a datapoint
each index maps to a datapoint; also saves the number of steps that precede it within the same trajectory
"""
indices = []
cur_traj_index = 0
for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps + 1
indices += list(range(cur_traj_index, max_start))
indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
]
cur_traj_index += traj_length
return indices
def set_train_val_split(self, train_split):
"""Not doing validation right now"""
"""
Not doing validation right now
"""
num_train = int(len(self.indices) * train_split)
train_indices = random.sample(self.indices, num_train)
val_indices = [i for i in range(len(self.indices)) if i not in train_indices]

View File

@ -3,6 +3,8 @@ Advantage-weighted regression (AWR) for diffusion policy.
Advantage = discounted-reward-to-go - V(s)
Does not support pixel input right now.
"""
import os
@ -131,6 +133,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -145,9 +148,10 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
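# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
# NOTE(review): last_itr_eval is assigned from eval_mode just above this check, so (3) never differs from (2) — confirm whether the assignment should come after the reset.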
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -155,7 +159,6 @@ class TrainAWRDiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
@ -165,16 +168,19 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode,
)
.cpu()
.numpy()
) # n_env x horizon x act
)
action_venv = samples[:, : self.act_steps]
# Apply multi-step action
@ -184,7 +190,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer
obs_buffer.append(prev_obs_venv)
obs_buffer.append(prev_obs_venv["state"])
action_buffer.append(action_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv)
@ -230,59 +236,46 @@ class TrainAWRDiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer))
obs_trajs = np.array(deepcopy(obs_buffer)) # assume only state
reward_trajs = np.array(deepcopy(reward_buffer))
dones_trajs = np.array(deepcopy(done_buffer))
obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
)
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
values_trajs = values_t.reshape(-1, self.n_envs)
values_trajs = np.array(
self.model.critic({"state": obs_t}).detach().cpu().numpy()
).reshape(-1, self.n_envs)
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
td_t = torch.from_numpy(td_trajs.flatten()).float().to(self.device)
# flatten
obs_trajs = einops.rearrange(
obs_trajs,
"s e h d -> (s e) h d",
)
td_trajs = einops.rearrange(
td_trajs,
"s e -> (s e)",
)
# Update policy and critic
# Update critic
num_batch = int(
self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
)
for _ in range(num_batch // self.critic_update_ratio):
# Sample batch
inds = np.random.choice(len(obs_trajs), self.batch_size)
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
td_b = torch.from_numpy(td_trajs[inds]).float().to(self.device)
# Update critic
loss_critic = self.model.loss_critic(obs_b, td_b)
loss_critic = self.model.loss_critic(
{"state": obs_t[inds]}, td_t[inds]
)
self.critic_optimizer.zero_grad()
loss_critic.backward()
self.critic_optimizer.step()
# Update policy - use a new copy of data
obs_trajs = np.array(deepcopy(obs_buffer))
samples_trajs = np.array(deepcopy(action_buffer))
reward_trajs = np.array(deepcopy(reward_buffer))
dones_trajs = np.array(deepcopy(done_buffer))
obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
)
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
values_trajs = values_t.reshape(-1, self.n_envs)
values_trajs = np.array(
self.model.critic({"state": obs_t}).detach().cpu().numpy()
).reshape(-1, self.n_envs)
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
advantages_trajs = td_trajs - values_trajs
@ -304,7 +297,11 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Sample batch
inds = np.random.choice(len(obs_trajs), self.batch_size)
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
obs_b = {
"state": torch.from_numpy(obs_trajs[inds])
.float()
.to(self.device)
}
actions_b = (
torch.from_numpy(samples_trajs[inds]).float().to(self.device)
)
@ -347,6 +344,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -367,7 +365,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -383,7 +381,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -4,6 +4,9 @@ Model-free online RL with DIffusion POlicy (DIPO)
Applies action gradient to perturb actions towards maximizer of Q-function.
a_t <- a_t + \eta * \grad_a Q(s, a)
Does not support pixel input right now.
"""
import os
@ -90,6 +93,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -104,9 +108,10 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
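# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
# NOTE(review): last_itr_eval is assigned from eval_mode just above this check, so (3) never differs from (2) — confirm whether the assignment should come after the reset.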
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -114,7 +119,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
@ -124,11 +128,14 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode,
)
.cpu()
@ -144,8 +151,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
# add to buffer
for i in range(self.n_envs):
obs_buffer.append(prev_obs_venv[i])
next_obs_buffer.append(obs_venv[i])
obs_buffer.append(prev_obs_venv["state"][i])
next_obs_buffer.append(obs_venv["state"][i])
action_buffer.append(action_venv[i])
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
done_buffer.append(done_venv[i])
@ -191,8 +198,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update models
if not eval_mode:
num_batch = self.replay_ratio
# Critic learning
@ -231,7 +238,12 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Update critic
loss_critic = self.model.loss_critic(
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
{"state": obs_b},
{"state": next_obs_b},
actions_b,
rewards_b,
dones_b,
self.gamma,
)
self.critic_optimizer.zero_grad()
loss_critic.backward()
@ -239,7 +251,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Actor learning
for _ in range(num_batch):
# Sample batch
inds = np.random.choice(len(obs_buffer), self.batch_size)
obs_b = (
@ -265,7 +276,9 @@ class TrainDIPODiffusionAgent(TrainAgent):
)
for _ in range(self.action_gradient_steps):
actions_flat.requires_grad_(True)
q_values_1, q_values_2 = self.model.critic(obs_b, actions_flat)
q_values_1, q_values_2 = self.model.critic(
{"state": obs_b}, actions_flat
)
q_values = torch.min(q_values_1, q_values_2)
action_opt_loss = -q_values.sum()
@ -291,7 +304,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
)
# Update policy with collected trajectories
loss = self.model.loss(guided_action.detach(), {0: obs_b})
loss = self.model.loss(guided_action.detach(), {"state": obs_b})
self.actor_optimizer.zero_grad()
loss.backward()
if self.itr >= self.n_critic_warmup_itr:
@ -316,6 +329,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -336,7 +350,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -352,7 +366,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -5,6 +5,9 @@ Learns a critic Q-function and backprops the expected Q-value to train the actor
pi = argmin L_d(\theta) - \alpha * E[Q(s, a)]
L_d is demonstration loss for regularization
Does not support pixel input right now.
"""
import os
@ -38,7 +41,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
lr=cfg.train.actor_lr,
weight_decay=cfg.train.actor_weight_decay,
)
# use cosine scheduler with linear warmup
self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
self.actor_optimizer,
first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
@ -88,6 +90,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -102,9 +105,10 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
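# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
# NOTE(review): last_itr_eval is assigned from eval_mode just above this check, so (3) never differs from (2) — confirm whether the assignment should come after the reset.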
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -112,7 +116,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
@ -122,11 +125,14 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode,
)
.cpu()
@ -142,8 +148,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
# add to buffer
for i in range(self.n_envs):
obs_buffer.append(prev_obs_venv[i])
next_obs_buffer.append(obs_venv[i])
obs_buffer.append(prev_obs_venv["state"][i])
next_obs_buffer.append(obs_venv["state"][i])
action_buffer.append(action_venv[i])
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
done_buffer.append(done_venv[i])
@ -189,8 +195,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update models
if not eval_mode:
num_batch = self.replay_ratio
# Critic learning
@ -229,7 +235,12 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Update critic
loss_critic = self.model.loss_critic(
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
{"state": obs_b},
{"state": next_obs_b},
actions_b,
rewards_b,
dones_b,
self.gamma,
)
self.critic_optimizer.zero_grad()
loss_critic.backward()
@ -237,19 +248,21 @@ class TrainDQLDiffusionAgent(TrainAgent):
# get the new action and q values
samples = self.model.forward_train(
cond=obs_b.to(self.device),
cond={"state": obs_b},
deterministic=eval_mode,
)
output_venv = samples # n_env x horizon x act
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
actions_flat_b = action_venv.reshape(action_venv.shape[0], -1)
q_values_b = self.model.critic(obs_b, actions_flat_b)
action_venv = samples[:, : self.act_steps] # n_env x horizon x act
q_values_b = self.model.critic({"state": obs_b}, action_venv)
q1_new_action, q2_new_action = q_values_b
# Update policy with collected trajectories
self.actor_optimizer.zero_grad()
actor_loss = self.model.loss_actor(
obs_b, actions_b, q1_new_action, q2_new_action, self.eta
{"state": obs_b},
actions_b,
q1_new_action,
q2_new_action,
self.eta,
)
actor_loss.backward()
if self.itr >= self.n_critic_warmup_itr:
@ -275,6 +288,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -295,7 +309,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -311,7 +325,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -1,6 +1,8 @@
"""
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
Does not support pixel input right now.
"""
import os
@ -11,7 +13,6 @@ import torch
import logging
import wandb
from copy import deepcopy
import random
log = logging.getLogger(__name__)
from util.timer import Timer
@ -98,8 +99,8 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# make a FIFO replay buffer for obs, action, and reward
obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
next_obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
reward_buffer = deque(maxlen=self.buffer_size)
done_buffer = deque(maxlen=self.buffer_size)
first_buffer = deque(maxlen=self.buffer_size)
@ -107,6 +108,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -121,9 +123,10 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
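# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
# NOTE(review): last_itr_eval is assigned from eval_mode just above this check, so (3) never differs from (2) — confirm whether the assignment should come after the reset.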
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -131,7 +134,6 @@ class TrainIDQLDiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
@ -141,11 +143,14 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode and self.eval_deterministic,
num_sample=self.num_sample,
use_expectile_exploration=self.use_expectile_exploration,
@ -162,9 +167,9 @@ class TrainIDQLDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer
obs_buffer.append(prev_obs_venv)
obs_buffer.append(prev_obs_venv["state"])
next_obs_buffer.append(obs_venv["state"])
action_buffer.append(action_venv)
next_obs_buffer.append(obs_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv)
first_buffer.append(firsts_trajs[step])
@ -209,7 +214,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer))
@ -257,14 +262,21 @@ class TrainIDQLDiffusionAgent(TrainAgent):
done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device)
# update critic value function
critic_loss_v = self.model.loss_critic_v(obs_b, actions_b)
critic_loss_v = self.model.loss_critic_v(
{"state": obs_b}, actions_b
)
self.critic_v_optimizer.zero_grad()
critic_loss_v.backward()
self.critic_v_optimizer.step()
# update critic q function
critic_loss_q = self.model.loss_critic_q(
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
{"state": obs_b},
{"state": next_obs_b},
actions_b,
reward_b,
done_b,
self.gamma,
)
self.critic_q_optimizer.zero_grad()
critic_loss_q.backward()
@ -278,7 +290,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Update policy with collected trajectories - no weighting
loss = self.model.loss(
actions_b,
obs_b,
{"state": obs_b},
)
self.actor_optimizer.zero_grad()
loss.backward()
@ -305,6 +317,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -325,7 +338,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -341,7 +354,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -10,6 +10,7 @@ import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
@ -78,7 +79,9 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
)
# Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
chains_trajs = np.empty(
(
0,
@ -91,8 +94,8 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
) # remove cond_step dim
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # save current obs
# Collect a set of trajectories from env
for step in range(self.n_steps):
@ -101,8 +104,13 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
cond=cond,
deterministic=eval_mode,
return_chain=True,
)
@ -118,14 +126,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
if self.save_full_observations:
obs_full_venv = np.vstack(
[info["full_obs"][None] for info in info_venv]
) # n_envs x n_act_steps x obs_dim
if self.save_full_observations: # state-only
obs_full_venv = np.array(
[info["full_obs"]["state"] for info in info_venv]
) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
)
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
@ -177,12 +187,22 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
# Update models
if not eval_mode:
with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory
obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_trajs["state"] = (
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
)
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten()
@ -219,7 +239,11 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -250,10 +274,12 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
returns_trajs = advantages_trajs + values_trajs
# k for environment step
obs_k = einops.rearrange(
torch.tensor(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_k = {
"state": einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
}
chains_k = einops.rearrange(
torch.tensor(chains_trajs).float().to(self.device),
"s e t h d -> (s e) t h d",
@ -283,7 +309,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b]
obs_b = {"state": obs_k["state"][inds_b]}
chains_b = chains_k[inds_b]
returns_b = returns_k[inds_b]
values_b = values_k[inds_b]
@ -351,7 +377,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
)
# Plot state trajectories in D3IL
# Plot state trajectories (only in D3IL)
if (
self.itr % self.render_freq == 0
and self.n_render > 0

View File

@ -30,7 +30,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
# Set obs dim - we will save the different obs in batch in a dict
shape_meta = cfg.shape_meta
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()}
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs}
# Gradient accumulation to deal with large GPU RAM usage
self.grad_accumulate = cfg.train.grad_accumulate
@ -95,7 +95,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
key: torch.from_numpy(prev_obs_venv[key])
.float()
.to(self.device)
for key in self.obs_dims.keys()
for key in self.obs_dims
} # batch each type of obs and put into dict
samples = self.model(
cond=cond,
@ -114,7 +114,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
for k in obs_trajs.keys():
for k in obs_trajs:
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
@ -159,7 +159,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
with torch.no_grad():
# apply image randomization
@ -187,7 +187,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
for k in obs_trajs.keys():
for k in obs_trajs:
obs_k = einops.rearrange(
obs_trajs[k],
"s e ... -> (s e) ...",
@ -238,7 +238,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = {
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims.keys()
for key in self.obs_dims
}
with torch.no_grad():
next_value = (
@ -278,7 +278,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
obs_trajs[k],
"s e ... -> (s e) ...",
)
for k in obs_trajs.keys()
for k in obs_trajs
}
chains_k = einops.rearrange(
torch.tensor(chains_trajs).float().to(self.device),
@ -309,7 +309,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
chains_b = chains_k[inds_b]
returns_b = returns_k[inds_b]
values_b = values_k[inds_b]
@ -387,7 +387,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
)
# Update lr
# Update lr, min_sampling_std
if self.itr >= self.n_critic_warmup_itr:
self.actor_lr_scheduler.step()
if self.learn_eta:
@ -407,6 +407,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -427,7 +428,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -462,7 +463,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -1,6 +1,8 @@
"""
Use diffusion exact likelihood for policy gradient.
Does not support pixel input yet.
"""
import os
@ -10,6 +12,7 @@ import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
@ -43,6 +46,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
eval_mode = False
self.model.eval() if eval_mode else self.model.train()
last_itr_eval = eval_mode
@ -58,7 +62,9 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
)
# Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty(
(
0,
@ -79,8 +85,8 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
) # remove cond_step dim
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # save current obs
# Collect a set of trajectories from env
for step in range(self.n_steps):
@ -89,8 +95,13 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
cond=cond,
deterministic=eval_mode,
return_chain=True,
)
@ -101,21 +112,23 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
samples.chains.cpu().numpy()
) # n_env x denoising x horizon x act
action_venv = output_venv[:, : self.act_steps]
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
if self.save_full_observations:
obs_full_venv = np.vstack(
[info["full_obs"][None] for info in info_venv]
) # n_envs x n_act_steps x obs_dim
if self.save_full_observations: # state-only
obs_full_venv = np.array(
[info["full_obs"]["state"] for info in info_venv]
) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
)
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
firsts_trajs[step + 1] = done_venv
@ -158,15 +171,25 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory
obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_trajs["state"] = (
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
)
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten()
@ -193,7 +216,11 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -224,10 +251,12 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
returns_trajs = advantages_trajs + values_trajs
# k for environment step
obs_k = einops.rearrange(
torch.tensor(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_k = {
"state": einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
}
samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device),
"s e h d -> (s e) h d",
@ -257,7 +286,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b]
obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b]
values_b = values_k[inds_b]
@ -318,7 +347,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
)
# Plot state trajectories
# Plot state trajectories (only in D3IL)
if (
self.itr % self.render_freq == 0
and self.n_render > 0
@ -354,6 +383,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["chains_trajs"] = chains_trajs
run_results[-1]["reward_trajs"] = reward_trajs
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -374,7 +404,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -399,7 +429,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -10,6 +10,7 @@ import numpy as np
import torch
import logging
import wandb
import math
log = logging.getLogger(__name__)
from util.timer import Timer
@ -55,7 +56,9 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
)
# Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty(
(
0,
@ -67,8 +70,8 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
) # remove cond_step dim
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # save current obs
# Collect a set of trajectories from env
for step in range(self.n_steps):
@ -77,26 +80,33 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
cond=cond,
deterministic=eval_mode,
)
output_venv = samples.cpu().numpy()
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
action_venv = output_venv[:, : self.act_steps]
# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
if self.save_full_observations:
obs_full_venv = np.vstack(
[info["full_obs"][None] for info in info_venv]
) # n_envs x n_act_steps x obs_dim
if self.save_full_observations: # state-only
obs_full_venv = np.array(
[info["full_obs"]["state"] for info in info_venv]
) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
)
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
firsts_trajs[step + 1] = done_venv
@ -144,15 +154,25 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory
obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_trajs["state"] = (
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
)
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten()
@ -184,7 +204,11 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad():
next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -215,10 +239,12 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
returns_trajs = advantages_trajs + values_trajs
# k for environment step
obs_k = einops.rearrange(
torch.tensor(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
obs_k = {
"state": einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
}
samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device),
"s e h d -> (s e) h d",
@ -250,7 +276,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b]
obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b]
values_b = values_k[inds_b]
@ -313,7 +339,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
)
# Plot state trajectories
# Plot state trajectories (only in D3IL)
if (
self.itr % self.render_freq == 0
and self.n_render > 0

View File

@ -94,8 +94,8 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
key: torch.from_numpy(prev_obs_venv[key])
.float()
.to(self.device)
for key in self.obs_dims.keys()
} # batch each type of obs and put into dict
for key in self.obs_dims
}
samples = self.model(
cond=cond,
deterministic=eval_mode,
@ -107,7 +107,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
for k in obs_trajs.keys():
for k in obs_trajs:
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
@ -152,7 +152,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
with torch.no_grad():
# apply image randomization
@ -180,7 +180,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
for k in obs_trajs.keys():
for k in obs_trajs:
obs_k = einops.rearrange(
obs_trajs[k],
"s e ... -> (s e) ...",
@ -226,7 +226,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = {
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims.keys()
for key in self.obs_dims
}
with torch.no_grad():
next_value = (
@ -266,7 +266,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
obs_trajs[k],
"s e ... -> (s e) ...",
)
for k in obs_trajs.keys()
for k in obs_trajs
}
samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device),
@ -297,7 +297,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b]
values_b = values_k[inds_b]
@ -383,6 +383,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -403,7 +404,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -437,7 +438,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -1,6 +1,8 @@
"""
QSM (Q-Score Matching) for diffusion policy.
Does not support pixel input right now.
"""
import os
@ -75,8 +77,8 @@ class TrainQSMDiffusionAgent(TrainAgent):
# make a FIFO replay buffer for obs, action, and reward
obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
next_obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
reward_buffer = deque(maxlen=self.buffer_size)
done_buffer = deque(maxlen=self.buffer_size)
first_buffer = deque(maxlen=self.buffer_size)
@ -84,6 +86,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -98,9 +101,10 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -108,7 +112,6 @@ class TrainQSMDiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
@ -118,11 +121,14 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode,
)
.cpu()
@ -137,9 +143,9 @@ class TrainQSMDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer
obs_buffer.append(prev_obs_venv)
obs_buffer.append(prev_obs_venv["state"])
next_obs_buffer.append(obs_venv["state"])
action_buffer.append(action_venv)
next_obs_buffer.append(obs_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv)
first_buffer.append(firsts_trajs[step])
@ -184,7 +190,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer))
@ -233,7 +239,12 @@ class TrainQSMDiffusionAgent(TrainAgent):
# update critic q function
critic_loss = self.model.loss_critic(
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
{"state": obs_b},
{"state": next_obs_b},
actions_b,
reward_b,
done_b,
self.gamma,
)
self.critic_optimizer.zero_grad()
critic_loss.backward()
@ -246,7 +257,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Update policy with collected trajectories
loss = self.model.loss_actor(
obs_b,
{"state": obs_b},
actions_b,
self.q_grad_coeff,
)
@ -274,6 +285,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -294,7 +306,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -310,7 +322,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -1,6 +1,8 @@
"""
Reward-weighted regression (RWR) for diffusion policy.
Does not support pixel input right now.
"""
import os
@ -54,6 +56,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Start training loop
timer = Timer()
run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr:
@ -68,9 +71,10 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
last_itr_eval = eval_mode
# Reset env at the beginning of an iteration
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1
@ -78,11 +82,11 @@ class TrainRWRDiffusionAgent(TrainAgent):
firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset
)
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Holders
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
# Holder
obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty(
(
0,
@ -91,6 +95,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
self.action_dim,
)
)
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env
for step in range(self.n_steps):
@ -99,24 +104,29 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Select action
with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = (
self.model(
cond=torch.from_numpy(prev_obs_venv)
.float()
.to(self.device),
cond=cond,
deterministic=eval_mode,
)
.cpu()
.numpy()
) # n_env x horizon x act
action_venv = samples[:, : self.act_steps]
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
samples_trajs = np.vstack((samples_trajs, samples[None]))
# Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv
)
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv
prev_obs_venv = obs_venv
@ -133,7 +143,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
if len(episodes_start_end) > 0:
# Compute transitions for completed trajectories
obs_trajs_split = [
obs_trajs[start : end + 1, env_ind]
{"state": obs_trajs["state"][start : end + 1, env_ind]}
for env_ind, start, end in episodes_start_end
]
samples_trajs_split = [
@ -183,17 +193,20 @@ class TrainRWRDiffusionAgent(TrainAgent):
success_rate = 0
log.info("[WARNING] No episode completed within the iteration!")
# Update
# Update models
if not eval_mode:
# Tensorize data and put them to device
# k for environment step
obs_k = (
torch.tensor(np.concatenate(obs_trajs_split))
obs_k = {
"state": torch.tensor(
np.concatenate(
[obs_traj["state"] for obs_traj in obs_trajs_split]
)
)
.float()
.to(self.device)
)
}
samples_k = (
torch.tensor(np.concatenate(samples_trajs_split))
.float()
@ -204,19 +217,15 @@ class TrainRWRDiffusionAgent(TrainAgent):
returns_trajs_split = (
returns_trajs_split - np.mean(returns_trajs_split)
) / (returns_trajs_split.std() + 1e-3)
rewards_k = (
torch.tensor(returns_trajs_split)
.float()
.to(self.device)
.reshape(-1)
)
rewards_k_scaled = torch.exp(self.beta * rewards_k)
rewards_k_scaled.clamp_(max=self.max_reward_weight)
# rewards_k_scaled = rewards_k_scaled / rewards_k_scaled.mean()
# Update policy and critic
total_steps = len(rewards_k_scaled)
inds_k = np.arange(total_steps)
@ -229,7 +238,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
start = batch * self.batch_size
end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b]
obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b]
rewards_b = rewards_k_scaled[inds_b]
@ -261,6 +270,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
}
)
if self.itr % self.log_freq == 0:
time = timer()
if eval_mode:
log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -281,7 +291,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward
else:
log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
)
if self.use_wandb:
wandb.log(
@ -295,7 +305,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
)
run_results[-1]["loss"] = loss
run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer()
run_results[-1]["time"] = time
with open(self.result_path, "wb") as f:
pickle.dump(run_results, f)
self.itr += 1

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256]
residual_style: True
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False
fixed_std: 0.1
num_modes: ${num_modes}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False
fixed_std: 0.1
num_modes: ${num_modes}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False
fixed_std: 0.1
num_modes: ${num_modes}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -100,10 +100,9 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -58,9 +58,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -58,9 +58,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
critic_v:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -99,20 +99,18 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
critic_v:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -98,20 +98,18 @@ model:
time_dim: 16
mlp_dims: [512, 512, 512]
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
critic_v:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512]
activation_type: ReLU
residual_style: True
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
cond_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16
mlp_dims: [512, 512, 512]
activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device}
ema:

View File

@ -94,13 +94,12 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -95,13 +95,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -94,13 +94,12 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -97,19 +97,18 @@ model:
_target_: model.common.critic.CriticObsAct
action_dim: ${action_dim}
action_steps: ${act_steps}
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
critic_v:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -100,15 +100,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -20,6 +20,7 @@ transition_dim: ${action_dim}
denoising_steps: 100
ft_denoising_steps: 5
cond_steps: 1
img_cond_steps: 1
horizon_steps: 4
act_steps: 4
use_ddim: True
@ -121,6 +122,7 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # image history (img_cond_steps frames) is concatenated along the channel dimension
cfg:
patch_size: 8
depth: 1
@ -133,6 +135,7 @@ model:
time_dim: 32
mlp_dims: [512, 512, 512]
residual_style: True
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
@ -143,6 +146,7 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg:
patch_size: 8
depth: 1
@ -150,15 +154,14 @@ model:
num_heads: 4
embed_style: embed2
embed_norm: 0
obs_dim: ${obs_dim}
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -103,15 +103,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim}
action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps}
device: ${device}

View File

@ -94,10 +94,9 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -18,6 +18,7 @@ obs_dim: 9
action_dim: 7
transition_dim: ${action_dim}
cond_steps: 1
img_cond_steps: 1
horizon_steps: 4
act_steps: 4
@ -100,6 +101,7 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg:
patch_size: 8
depth: 1
@ -115,6 +117,7 @@ model:
learn_fixed_std: True
std_min: 0.01
std_max: 0.2
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim}
@ -125,6 +128,7 @@ model:
backbone:
_target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg:
patch_size: 8
depth: 1
@ -132,10 +136,10 @@ model:
num_heads: 4
embed_style: embed2
embed_norm: 0
obs_dim: ${obs_dim}
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

View File

@ -26,7 +26,7 @@ env:
name: ${env_name}
best_reward_threshold_for_success: 1
max_episode_steps: 300
save_video: false
save_video: False
wrappers:
robomimic_lowdim:
normalization_path: ${normalization_path}
@ -95,10 +95,9 @@ model:
transition_dim: ${transition_dim}
critic:
_target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256]
activation_type: Mish
residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device}

Some files were not shown because too many files have changed in this diff Show More