squash commits

This commit is contained in:
allenzren 2024-09-11 21:09:17 -04:00
parent 8ce0aa1485
commit 2ddf63b8f5
200 changed files with 1240 additions and 1186 deletions

View File

@ -16,19 +16,21 @@ import random
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
Batch = namedtuple("Batch", "trajectories conditions") Batch = namedtuple("Batch", "actions conditions")
class StitchedSequenceDataset(torch.utils.data.Dataset): class StitchedSequenceDataset(torch.utils.data.Dataset):
""" """
Dataset to efficiently load and sample trajectories. Stitches episodes together in the time dimension to avoid excessive zero padding. Episode ID's are used to index unique trajectories. Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file.
Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is traj length of episode e and D is Use the first max_n_episodes episodes (instead of random sampling)
(tuple of) dimension of observation, action, images, etc.
Example: Example:
states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------] states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------] Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
Each sample is a namedtuple of (1) chunked actions and (2) a list (obs timesteps) of dictionary with keys states and images.
""" """
def __init__( def __init__(
@ -36,23 +38,30 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
dataset_path, dataset_path,
horizon_steps=64, horizon_steps=64,
cond_steps=1, cond_steps=1,
img_cond_steps=1,
max_n_episodes=10000, max_n_episodes=10000,
use_img=False, use_img=False,
device="cuda:0", device="cuda:0",
): ):
assert (
img_cond_steps <= cond_steps
), "consider using more cond_steps than img_cond_steps"
self.horizon_steps = horizon_steps self.horizon_steps = horizon_steps
self.cond_steps = cond_steps self.cond_steps = cond_steps # states (proprio, etc.)
self.img_cond_steps = img_cond_steps
self.device = device self.device = device
self.use_img = use_img self.use_img = use_img
# Load dataset to device specified # Load dataset to device specified
if dataset_path.endswith(".npz"): if dataset_path.endswith(".npz"):
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
else: elif dataset_path.endswith(".pkl"):
with open(dataset_path, "rb") as f: with open(dataset_path, "rb") as f:
dataset = pickle.load(f) dataset = pickle.load(f)
traj_lengths = dataset["traj_lengths"] # 1-D array else:
total_num_steps = np.sum(traj_lengths[:max_n_episodes]) raise ValueError(f"Unsupported file format: {dataset_path}")
traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array
total_num_steps = np.sum(traj_lengths)
# Set up indices for sampling # Set up indices for sampling
self.indices = self.make_indices(traj_lengths, horizon_steps) self.indices = self.make_indices(traj_lengths, horizon_steps)
@ -75,35 +84,51 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}") log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
def __getitem__(self, idx): def __getitem__(self, idx):
start = self.indices[idx] """
repeat states/images if using history observation at the beginning of the episode
"""
start, num_before_start = self.indices[idx]
end = start + self.horizon_steps end = start + self.horizon_steps
states = self.states[start:end] states = self.states[(start - num_before_start) : end]
actions = self.actions[start:end] actions = self.actions[start:end]
states = torch.stack(
[
states[min(num_before_start - t, 0)]
for t in reversed(range(self.cond_steps))
]
) # more recent is at the end
conditions = {"state": states}
if self.use_img: if self.use_img:
images = self.images[start:end] images = self.images[(start - num_before_start) : end]
conditions = { images = torch.stack(
1 - self.cond_steps: {"state": states[0], "rgb": images[0]} [
} # TODO: allow obs history, -1, -2, ... images[min(num_before_start - t, 0)]
else: for t in reversed(range(self.img_cond_steps))
conditions = {1 - self.cond_steps: states[0]} ]
)
conditions["rgb"] = images
batch = Batch(actions, conditions) batch = Batch(actions, conditions)
return batch return batch
def make_indices(self, traj_lengths, horizon_steps): def make_indices(self, traj_lengths, horizon_steps):
""" """
makes indices for sampling from dataset; makes indices for sampling from dataset;
each index maps to a datapoint each index maps to a datapoint, also save the number of steps before it within the same trajectory
""" """
indices = [] indices = []
cur_traj_index = 0 cur_traj_index = 0
for traj_length in traj_lengths: for traj_length in traj_lengths:
max_start = cur_traj_index + traj_length - horizon_steps + 1 max_start = cur_traj_index + traj_length - horizon_steps + 1
indices += list(range(cur_traj_index, max_start)) indices += [
(i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
]
cur_traj_index += traj_length cur_traj_index += traj_length
return indices return indices
def set_train_val_split(self, train_split): def set_train_val_split(self, train_split):
"""Not doing validation right now""" """
Not doing validation right now
"""
num_train = int(len(self.indices) * train_split) num_train = int(len(self.indices) * train_split)
train_indices = random.sample(self.indices, num_train) train_indices = random.sample(self.indices, num_train)
val_indices = [i for i in range(len(self.indices)) if i not in train_indices] val_indices = [i for i in range(len(self.indices)) if i not in train_indices]

View File

@ -3,6 +3,8 @@ Advantage-weighted regression (AWR) for diffusion policy.
Advantage = discounted-reward-to-go - V(s) Advantage = discounted-reward-to-go - V(s)
Do not support pixel input right now.
""" """
import os import os
@ -131,6 +133,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -145,9 +148,10 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -155,7 +159,6 @@ class TrainAWRDiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
@ -165,16 +168,19 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode, deterministic=eval_mode,
) )
.cpu() .cpu()
.numpy() .numpy()
) # n_env x horizon x act )
action_venv = samples[:, : self.act_steps] action_venv = samples[:, : self.act_steps]
# Apply multi-step action # Apply multi-step action
@ -184,7 +190,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer # add to buffer
obs_buffer.append(prev_obs_venv) obs_buffer.append(prev_obs_venv["state"])
action_buffer.append(action_venv) action_buffer.append(action_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor) reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv) done_buffer.append(done_venv)
@ -230,59 +236,46 @@ class TrainAWRDiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer)) # assume only state
obs_trajs = np.array(deepcopy(obs_buffer))
reward_trajs = np.array(deepcopy(reward_buffer)) reward_trajs = np.array(deepcopy(reward_buffer))
dones_trajs = np.array(deepcopy(done_buffer)) dones_trajs = np.array(deepcopy(done_buffer))
obs_t = einops.rearrange( obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device), torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d", "s e h d -> (s e) h d",
) )
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy()) values_trajs = np.array(
values_trajs = values_t.reshape(-1, self.n_envs) self.model.critic({"state": obs_t}).detach().cpu().numpy()
).reshape(-1, self.n_envs)
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs) td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
td_t = torch.from_numpy(td_trajs.flatten()).float().to(self.device)
# flatten # Update critic
obs_trajs = einops.rearrange(
obs_trajs,
"s e h d -> (s e) h d",
)
td_trajs = einops.rearrange(
td_trajs,
"s e -> (s e)",
)
# Update policy and critic
num_batch = int( num_batch = int(
self.n_steps * self.n_envs / self.batch_size * self.replay_ratio self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
) )
for _ in range(num_batch // self.critic_update_ratio): for _ in range(num_batch // self.critic_update_ratio):
# Sample batch
inds = np.random.choice(len(obs_trajs), self.batch_size) inds = np.random.choice(len(obs_trajs), self.batch_size)
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device) loss_critic = self.model.loss_critic(
td_b = torch.from_numpy(td_trajs[inds]).float().to(self.device) {"state": obs_t[inds]}, td_t[inds]
)
# Update critic
loss_critic = self.model.loss_critic(obs_b, td_b)
self.critic_optimizer.zero_grad() self.critic_optimizer.zero_grad()
loss_critic.backward() loss_critic.backward()
self.critic_optimizer.step() self.critic_optimizer.step()
# Update policy - use a new copy of data
obs_trajs = np.array(deepcopy(obs_buffer)) obs_trajs = np.array(deepcopy(obs_buffer))
samples_trajs = np.array(deepcopy(action_buffer)) samples_trajs = np.array(deepcopy(action_buffer))
reward_trajs = np.array(deepcopy(reward_buffer)) reward_trajs = np.array(deepcopy(reward_buffer))
dones_trajs = np.array(deepcopy(done_buffer)) dones_trajs = np.array(deepcopy(done_buffer))
obs_t = einops.rearrange( obs_t = einops.rearrange(
torch.from_numpy(obs_trajs).float().to(self.device), torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d", "s e h d -> (s e) h d",
) )
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy()) values_trajs = np.array(
values_trajs = values_t.reshape(-1, self.n_envs) self.model.critic({"state": obs_t}).detach().cpu().numpy()
).reshape(-1, self.n_envs)
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs) td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
advantages_trajs = td_trajs - values_trajs advantages_trajs = td_trajs - values_trajs
@ -304,7 +297,11 @@ class TrainAWRDiffusionAgent(TrainAgent):
# Sample batch # Sample batch
inds = np.random.choice(len(obs_trajs), self.batch_size) inds = np.random.choice(len(obs_trajs), self.batch_size)
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device) obs_b = {
"state": torch.from_numpy(obs_trajs[inds])
.float()
.to(self.device)
}
actions_b = ( actions_b = (
torch.from_numpy(samples_trajs[inds]).float().to(self.device) torch.from_numpy(samples_trajs[inds]).float().to(self.device)
) )
@ -347,6 +344,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -367,7 +365,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -383,7 +381,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -4,6 +4,9 @@ Model-free online RL with DIffusion POlicy (DIPO)
Applies action gradient to perturb actions towards maximizer of Q-function. Applies action gradient to perturb actions towards maximizer of Q-function.
a_t <- a_t + \eta * \grad_a Q(s, a) a_t <- a_t + \eta * \grad_a Q(s, a)
Do not support pixel input right now.
""" """
import os import os
@ -90,6 +93,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -104,9 +108,10 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -114,7 +119,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
@ -124,11 +128,14 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode, deterministic=eval_mode,
) )
.cpu() .cpu()
@ -144,8 +151,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
# add to buffer # add to buffer
for i in range(self.n_envs): for i in range(self.n_envs):
obs_buffer.append(prev_obs_venv[i]) obs_buffer.append(prev_obs_venv["state"][i])
next_obs_buffer.append(obs_venv[i]) next_obs_buffer.append(obs_venv["state"][i])
action_buffer.append(action_venv[i]) action_buffer.append(action_venv[i])
reward_buffer.append(reward_venv[i] * self.scale_reward_factor) reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
done_buffer.append(done_venv[i]) done_buffer.append(done_venv[i])
@ -191,8 +198,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update models
if not eval_mode: if not eval_mode:
num_batch = self.replay_ratio num_batch = self.replay_ratio
# Critic learning # Critic learning
@ -231,7 +238,12 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Update critic # Update critic
loss_critic = self.model.loss_critic( loss_critic = self.model.loss_critic(
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma {"state": obs_b},
{"state": next_obs_b},
actions_b,
rewards_b,
dones_b,
self.gamma,
) )
self.critic_optimizer.zero_grad() self.critic_optimizer.zero_grad()
loss_critic.backward() loss_critic.backward()
@ -239,7 +251,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
# Actor learning # Actor learning
for _ in range(num_batch): for _ in range(num_batch):
# Sample batch # Sample batch
inds = np.random.choice(len(obs_buffer), self.batch_size) inds = np.random.choice(len(obs_buffer), self.batch_size)
obs_b = ( obs_b = (
@ -265,7 +276,9 @@ class TrainDIPODiffusionAgent(TrainAgent):
) )
for _ in range(self.action_gradient_steps): for _ in range(self.action_gradient_steps):
actions_flat.requires_grad_(True) actions_flat.requires_grad_(True)
q_values_1, q_values_2 = self.model.critic(obs_b, actions_flat) q_values_1, q_values_2 = self.model.critic(
{"state": obs_b}, actions_flat
)
q_values = torch.min(q_values_1, q_values_2) q_values = torch.min(q_values_1, q_values_2)
action_opt_loss = -q_values.sum() action_opt_loss = -q_values.sum()
@ -291,7 +304,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
) )
# Update policy with collected trajectories # Update policy with collected trajectories
loss = self.model.loss(guided_action.detach(), {0: obs_b}) loss = self.model.loss(guided_action.detach(), {"state": obs_b})
self.actor_optimizer.zero_grad() self.actor_optimizer.zero_grad()
loss.backward() loss.backward()
if self.itr >= self.n_critic_warmup_itr: if self.itr >= self.n_critic_warmup_itr:
@ -316,6 +329,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -336,7 +350,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -352,7 +366,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -5,6 +5,9 @@ Learns a critic Q-function and backprops the expected Q-value to train the actor
pi = argmin L_d(\theta) - \alpha * E[Q(s, a)] pi = argmin L_d(\theta) - \alpha * E[Q(s, a)]
L_d is demonstration loss for regularization L_d is demonstration loss for regularization
Do not support pixel input right now.
""" """
import os import os
@ -38,7 +41,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
lr=cfg.train.actor_lr, lr=cfg.train.actor_lr,
weight_decay=cfg.train.actor_weight_decay, weight_decay=cfg.train.actor_weight_decay,
) )
# use cosine scheduler with linear warmup
self.actor_lr_scheduler = CosineAnnealingWarmupRestarts( self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
self.actor_optimizer, self.actor_optimizer,
first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps, first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
@ -88,6 +90,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -102,9 +105,10 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -112,7 +116,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
@ -122,11 +125,14 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode, deterministic=eval_mode,
) )
.cpu() .cpu()
@ -142,8 +148,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
# add to buffer # add to buffer
for i in range(self.n_envs): for i in range(self.n_envs):
obs_buffer.append(prev_obs_venv[i]) obs_buffer.append(prev_obs_venv["state"][i])
next_obs_buffer.append(obs_venv[i]) next_obs_buffer.append(obs_venv["state"][i])
action_buffer.append(action_venv[i]) action_buffer.append(action_venv[i])
reward_buffer.append(reward_venv[i] * self.scale_reward_factor) reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
done_buffer.append(done_venv[i]) done_buffer.append(done_venv[i])
@ -189,8 +195,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update models
if not eval_mode: if not eval_mode:
num_batch = self.replay_ratio num_batch = self.replay_ratio
# Critic learning # Critic learning
@ -229,7 +235,12 @@ class TrainDQLDiffusionAgent(TrainAgent):
# Update critic # Update critic
loss_critic = self.model.loss_critic( loss_critic = self.model.loss_critic(
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma {"state": obs_b},
{"state": next_obs_b},
actions_b,
rewards_b,
dones_b,
self.gamma,
) )
self.critic_optimizer.zero_grad() self.critic_optimizer.zero_grad()
loss_critic.backward() loss_critic.backward()
@ -237,19 +248,21 @@ class TrainDQLDiffusionAgent(TrainAgent):
# get the new action and q values # get the new action and q values
samples = self.model.forward_train( samples = self.model.forward_train(
cond=obs_b.to(self.device), cond={"state": obs_b},
deterministic=eval_mode, deterministic=eval_mode,
) )
output_venv = samples # n_env x horizon x act action_venv = samples[:, : self.act_steps] # n_env x horizon x act
action_venv = output_venv[:, : self.act_steps, : self.action_dim] q_values_b = self.model.critic({"state": obs_b}, action_venv)
actions_flat_b = action_venv.reshape(action_venv.shape[0], -1)
q_values_b = self.model.critic(obs_b, actions_flat_b)
q1_new_action, q2_new_action = q_values_b q1_new_action, q2_new_action = q_values_b
# Update policy with collected trajectories # Update policy with collected trajectories
self.actor_optimizer.zero_grad() self.actor_optimizer.zero_grad()
actor_loss = self.model.loss_actor( actor_loss = self.model.loss_actor(
obs_b, actions_b, q1_new_action, q2_new_action, self.eta {"state": obs_b},
actions_b,
q1_new_action,
q2_new_action,
self.eta,
) )
actor_loss.backward() actor_loss.backward()
if self.itr >= self.n_critic_warmup_itr: if self.itr >= self.n_critic_warmup_itr:
@ -275,6 +288,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -295,7 +309,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -311,7 +325,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -1,6 +1,8 @@
""" """
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy. Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
Do not support pixel input right now.
""" """
import os import os
@ -11,7 +13,6 @@ import torch
import logging import logging
import wandb import wandb
from copy import deepcopy from copy import deepcopy
import random
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from util.timer import Timer from util.timer import Timer
@ -98,8 +99,8 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# make a FIFO replay buffer for obs, action, and reward # make a FIFO replay buffer for obs, action, and reward
obs_buffer = deque(maxlen=self.buffer_size) obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
next_obs_buffer = deque(maxlen=self.buffer_size) next_obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
reward_buffer = deque(maxlen=self.buffer_size) reward_buffer = deque(maxlen=self.buffer_size)
done_buffer = deque(maxlen=self.buffer_size) done_buffer = deque(maxlen=self.buffer_size)
first_buffer = deque(maxlen=self.buffer_size) first_buffer = deque(maxlen=self.buffer_size)
@ -107,6 +108,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -121,9 +123,10 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -131,7 +134,6 @@ class TrainIDQLDiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
@ -141,11 +143,14 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode and self.eval_deterministic, deterministic=eval_mode and self.eval_deterministic,
num_sample=self.num_sample, num_sample=self.num_sample,
use_expectile_exploration=self.use_expectile_exploration, use_expectile_exploration=self.use_expectile_exploration,
@ -162,9 +167,9 @@ class TrainIDQLDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer # add to buffer
obs_buffer.append(prev_obs_venv) obs_buffer.append(prev_obs_venv["state"])
next_obs_buffer.append(obs_venv["state"])
action_buffer.append(action_venv) action_buffer.append(action_venv)
next_obs_buffer.append(obs_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor) reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv) done_buffer.append(done_venv)
first_buffer.append(firsts_trajs[step]) first_buffer.append(firsts_trajs[step])
@ -209,7 +214,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer)) obs_trajs = np.array(deepcopy(obs_buffer))
@ -257,14 +262,21 @@ class TrainIDQLDiffusionAgent(TrainAgent):
done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device) done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device)
# update critic value function # update critic value function
critic_loss_v = self.model.loss_critic_v(obs_b, actions_b) critic_loss_v = self.model.loss_critic_v(
{"state": obs_b}, actions_b
)
self.critic_v_optimizer.zero_grad() self.critic_v_optimizer.zero_grad()
critic_loss_v.backward() critic_loss_v.backward()
self.critic_v_optimizer.step() self.critic_v_optimizer.step()
# update critic q function # update critic q function
critic_loss_q = self.model.loss_critic_q( critic_loss_q = self.model.loss_critic_q(
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma {"state": obs_b},
{"state": next_obs_b},
actions_b,
reward_b,
done_b,
self.gamma,
) )
self.critic_q_optimizer.zero_grad() self.critic_q_optimizer.zero_grad()
critic_loss_q.backward() critic_loss_q.backward()
@ -278,7 +290,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
# Update policy with collected trajectories - no weighting # Update policy with collected trajectories - no weighting
loss = self.model.loss( loss = self.model.loss(
actions_b, actions_b,
obs_b, {"state": obs_b},
) )
self.actor_optimizer.zero_grad() self.actor_optimizer.zero_grad()
loss.backward() loss.backward()
@ -305,6 +317,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -325,7 +338,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -341,7 +354,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -10,6 +10,7 @@ import numpy as np
import torch import torch
import logging import logging
import wandb import wandb
import math
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from util.timer import Timer from util.timer import Timer
@ -78,7 +79,9 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
) )
# Holder # Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
chains_trajs = np.empty( chains_trajs = np.empty(
( (
0, 0,
@ -91,8 +94,8 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2)) (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # remove cond_step dim ) # save current obs
# Collect a set of trajectories from env # Collect a set of trajectories from env
for step in range(self.n_steps): for step in range(self.n_steps):
@ -101,8 +104,13 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model( samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device), cond=cond,
deterministic=eval_mode, deterministic=eval_mode,
return_chain=True, return_chain=True,
) )
@ -118,14 +126,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
if self.save_full_observations: if self.save_full_observations: # state-only
obs_full_venv = np.vstack( obs_full_venv = np.array(
[info["full_obs"][None] for info in info_venv] [info["full_obs"]["state"] for info in info_venv]
) # n_envs x n_act_steps x obs_dim ) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
) )
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None])) obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
chains_trajs = np.vstack((chains_trajs, chains_venv[None])) chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None]))
@ -177,12 +187,22 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
# Update models # Update models
if not eval_mode: if not eval_mode:
with torch.no_grad(): with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory obs_trajs["state"] = (
obs_t = einops.rearrange( torch.from_numpy(obs_trajs["state"]).float().to(self.device)
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
) )
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs)) values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts: for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten() values = self.model.critic(obs).cpu().numpy().flatten()
@ -219,7 +239,11 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
reward_trajs = reward_trajs_transpose.T reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified # bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad(): with torch.no_grad():
next_value = ( next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -250,10 +274,12 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
returns_trajs = advantages_trajs + values_trajs returns_trajs = advantages_trajs + values_trajs
# k for environment step # k for environment step
obs_k = einops.rearrange( obs_k = {
torch.tensor(obs_trajs).float().to(self.device), "state": einops.rearrange(
"s e h d -> (s e) h d", obs_trajs["state"],
) "s e ... -> (s e) ...",
)
}
chains_k = einops.rearrange( chains_k = einops.rearrange(
torch.tensor(chains_trajs).float().to(self.device), torch.tensor(chains_trajs).float().to(self.device),
"s e t h d -> (s e) t h d", "s e t h d -> (s e) t h d",
@ -283,7 +309,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b] obs_b = {"state": obs_k["state"][inds_b]}
chains_b = chains_k[inds_b] chains_b = chains_k[inds_b]
returns_b = returns_k[inds_b] returns_b = returns_k[inds_b]
values_b = values_k[inds_b] values_b = values_k[inds_b]
@ -351,7 +377,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
) )
# Plot state trajectories in D3IL # Plot state trajectories (only in D3IL)
if ( if (
self.itr % self.render_freq == 0 self.itr % self.render_freq == 0
and self.n_render > 0 and self.n_render > 0

View File

@ -30,7 +30,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
# Set obs dim - we will save the different obs in batch in a dict # Set obs dim - we will save the different obs in batch in a dict
shape_meta = cfg.shape_meta shape_meta = cfg.shape_meta
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()} self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs}
# Gradient accumulation to deal with large GPU RAM usage # Gradient accumulation to deal with large GPU RAM usage
self.grad_accumulate = cfg.train.grad_accumulate self.grad_accumulate = cfg.train.grad_accumulate
@ -95,7 +95,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
key: torch.from_numpy(prev_obs_venv[key]) key: torch.from_numpy(prev_obs_venv[key])
.float() .float()
.to(self.device) .to(self.device)
for key in self.obs_dims.keys() for key in self.obs_dims
} # batch each type of obs and put into dict } # batch each type of obs and put into dict
samples = self.model( samples = self.model(
cond=cond, cond=cond,
@ -114,7 +114,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
for k in obs_trajs.keys(): for k in obs_trajs:
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None])) obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
chains_trajs = np.vstack((chains_trajs, chains_venv[None])) chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
@ -159,7 +159,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
with torch.no_grad(): with torch.no_grad():
# apply image randomization # apply image randomization
@ -187,7 +187,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
self.n_envs * self.n_steps / self.logprob_batch_size self.n_envs * self.n_steps / self.logprob_batch_size
) )
obs_ts = [{} for _ in range(num_split)] obs_ts = [{} for _ in range(num_split)]
for k in obs_trajs.keys(): for k in obs_trajs:
obs_k = einops.rearrange( obs_k = einops.rearrange(
obs_trajs[k], obs_trajs[k],
"s e ... -> (s e) ...", "s e ... -> (s e) ...",
@ -238,7 +238,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
# bootstrap value with GAE if not done - apply reward scaling with constant if specified # bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = { obs_venv_ts = {
key: torch.from_numpy(obs_venv[key]).float().to(self.device) key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims.keys() for key in self.obs_dims
} }
with torch.no_grad(): with torch.no_grad():
next_value = ( next_value = (
@ -278,7 +278,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
obs_trajs[k], obs_trajs[k],
"s e ... -> (s e) ...", "s e ... -> (s e) ...",
) )
for k in obs_trajs.keys() for k in obs_trajs
} }
chains_k = einops.rearrange( chains_k = einops.rearrange(
torch.tensor(chains_trajs).float().to(self.device), torch.tensor(chains_trajs).float().to(self.device),
@ -309,7 +309,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()} obs_b = {k: obs_k[k][inds_b] for k in obs_k}
chains_b = chains_k[inds_b] chains_b = chains_k[inds_b]
returns_b = returns_k[inds_b] returns_b = returns_k[inds_b]
values_b = values_k[inds_b] values_b = values_k[inds_b]
@ -387,7 +387,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
) )
# Update lr # Update lr, min_sampling_std
if self.itr >= self.n_critic_warmup_itr: if self.itr >= self.n_critic_warmup_itr:
self.actor_lr_scheduler.step() self.actor_lr_scheduler.step()
if self.learn_eta: if self.learn_eta:
@ -407,6 +407,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -427,7 +428,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -462,7 +463,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -1,6 +1,8 @@
""" """
Use diffusion exact likelihood for policy gradient. Use diffusion exact likelihood for policy gradient.
Do not support pixel input yet.
""" """
import os import os
@ -10,6 +12,7 @@ import numpy as np
import torch import torch
import logging import logging
import wandb import wandb
import math
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from util.timer import Timer from util.timer import Timer
@ -43,6 +46,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
eval_mode = False
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
last_itr_eval = eval_mode last_itr_eval = eval_mode
@ -58,7 +62,9 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
) )
# Holder # Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty( samples_trajs = np.empty(
( (
0, 0,
@ -79,8 +85,8 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2)) (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # remove cond_step dim ) # save current obs
# Collect a set of trajectories from env # Collect a set of trajectories from env
for step in range(self.n_steps): for step in range(self.n_steps):
@ -89,8 +95,13 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model( samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device), cond=cond,
deterministic=eval_mode, deterministic=eval_mode,
return_chain=True, return_chain=True,
) )
@ -101,21 +112,23 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
samples.chains.cpu().numpy() samples.chains.cpu().numpy()
) # n_env x denoising x horizon x act ) # n_env x denoising x horizon x act
action_venv = output_venv[:, : self.act_steps] action_venv = output_venv[:, : self.act_steps]
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
# Apply multi-step action # Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
if self.save_full_observations: if self.save_full_observations: # state-only
obs_full_venv = np.vstack( obs_full_venv = np.array(
[info["full_obs"][None] for info in info_venv] [info["full_obs"]["state"] for info in info_venv]
) # n_envs x n_act_steps x obs_dim ) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
) )
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None])) obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
chains_trajs = np.vstack((chains_trajs, chains_venv[None])) chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None]))
firsts_trajs[step + 1] = done_venv firsts_trajs[step + 1] = done_venv
@ -158,15 +171,25 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
with torch.no_grad(): with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory obs_trajs["state"] = (
obs_t = einops.rearrange( torch.from_numpy(obs_trajs["state"]).float().to(self.device)
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
) )
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs)) values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts: for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten() values = self.model.critic(obs).cpu().numpy().flatten()
@ -193,7 +216,11 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
reward_trajs = reward_trajs_transpose.T reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified # bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad(): with torch.no_grad():
next_value = ( next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -224,10 +251,12 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
returns_trajs = advantages_trajs + values_trajs returns_trajs = advantages_trajs + values_trajs
# k for environment step # k for environment step
obs_k = einops.rearrange( obs_k = {
torch.tensor(obs_trajs).float().to(self.device), "state": einops.rearrange(
"s e h d -> (s e) h d", obs_trajs["state"],
) "s e ... -> (s e) ...",
)
}
samples_k = einops.rearrange( samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device), torch.tensor(samples_trajs).float().to(self.device),
"s e h d -> (s e) h d", "s e h d -> (s e) h d",
@ -257,7 +286,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b] obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b] samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b] returns_b = returns_k[inds_b]
values_b = values_k[inds_b] values_b = values_k[inds_b]
@ -318,7 +347,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
) )
# Plot state trajectories # Plot state trajectories (only in D3IL)
if ( if (
self.itr % self.render_freq == 0 self.itr % self.render_freq == 0
and self.n_render > 0 and self.n_render > 0
@ -354,6 +383,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["chains_trajs"] = chains_trajs run_results[-1]["chains_trajs"] = chains_trajs
run_results[-1]["reward_trajs"] = reward_trajs run_results[-1]["reward_trajs"] = reward_trajs
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -374,7 +404,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -399,7 +429,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -10,6 +10,7 @@ import numpy as np
import torch import torch
import logging import logging
import wandb import wandb
import math
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
from util.timer import Timer from util.timer import Timer
@ -55,7 +56,9 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
) )
# Holder # Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty( samples_trajs = np.empty(
( (
0, 0,
@ -67,8 +70,8 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim)) obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, prev_obs_venv[None].squeeze(2)) (obs_full_trajs, prev_obs_venv["state"][:, -1][None])
) # remove cond_step dim ) # save current obs
# Collect a set of trajectories from env # Collect a set of trajectories from env
for step in range(self.n_steps): for step in range(self.n_steps):
@ -77,26 +80,33 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = self.model( samples = self.model(
cond=torch.from_numpy(prev_obs_venv).float().to(self.device), cond=cond,
deterministic=eval_mode, deterministic=eval_mode,
) )
output_venv = samples.cpu().numpy() output_venv = samples.cpu().numpy()
action_venv = output_venv[:, : self.act_steps, : self.action_dim] action_venv = output_venv[:, : self.act_steps]
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
# Apply multi-step action # Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
if self.save_full_observations: if self.save_full_observations: # state-only
obs_full_venv = np.vstack( obs_full_venv = np.array(
[info["full_obs"][None] for info in info_venv] [info["full_obs"]["state"] for info in info_venv]
) # n_envs x n_act_steps x obs_dim ) # n_envs x act_steps x obs_dim
obs_full_trajs = np.vstack( obs_full_trajs = np.vstack(
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2)) (obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
) )
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
dones_trajs = np.vstack((dones_trajs, done_venv[None])) dones_trajs = np.vstack((dones_trajs, done_venv[None]))
firsts_trajs[step + 1] = done_venv firsts_trajs[step + 1] = done_venv
@ -144,15 +154,25 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
with torch.no_grad(): with torch.no_grad():
# Calculate value and logprobs - split into batches to prevent out of memory obs_trajs["state"] = (
obs_t = einops.rearrange( torch.from_numpy(obs_trajs["state"]).float().to(self.device)
torch.from_numpy(obs_trajs).float().to(self.device),
"s e h d -> (s e) h d",
) )
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
# Calculate value and logprobs - split into batches to prevent out of memory
num_split = math.ceil(
self.n_envs * self.n_steps / self.logprob_batch_size
)
obs_ts = [{} for _ in range(num_split)]
obs_k = einops.rearrange(
obs_trajs["state"],
"s e ... -> (s e) ...",
)
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
for i, obs_t in enumerate(obs_ts_k):
obs_ts[i]["state"] = obs_t
values_trajs = np.empty((0, self.n_envs)) values_trajs = np.empty((0, self.n_envs))
for obs in obs_ts: for obs in obs_ts:
values = self.model.critic(obs).cpu().numpy().flatten() values = self.model.critic(obs).cpu().numpy().flatten()
@ -184,7 +204,11 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
reward_trajs = reward_trajs_transpose.T reward_trajs = reward_trajs_transpose.T
# bootstrap value with GAE if not done - apply reward scaling with constant if specified # bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device) obs_venv_ts = {
"state": torch.from_numpy(obs_venv["state"])
.float()
.to(self.device)
}
with torch.no_grad(): with torch.no_grad():
next_value = ( next_value = (
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy() self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
@ -215,10 +239,12 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
returns_trajs = advantages_trajs + values_trajs returns_trajs = advantages_trajs + values_trajs
# k for environment step # k for environment step
obs_k = einops.rearrange( obs_k = {
torch.tensor(obs_trajs).float().to(self.device), "state": einops.rearrange(
"s e h d -> (s e) h d", obs_trajs["state"],
) "s e ... -> (s e) ...",
)
}
samples_k = einops.rearrange( samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device), torch.tensor(samples_trajs).float().to(self.device),
"s e h d -> (s e) h d", "s e h d -> (s e) h d",
@ -250,7 +276,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b] obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b] samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b] returns_b = returns_k[inds_b]
values_b = values_k[inds_b] values_b = values_k[inds_b]
@ -313,7 +339,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
) )
# Plot state trajectories # Plot state trajectories (only in D3IL)
if ( if (
self.itr % self.render_freq == 0 self.itr % self.render_freq == 0
and self.n_render > 0 and self.n_render > 0

View File

@ -94,8 +94,8 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
key: torch.from_numpy(prev_obs_venv[key]) key: torch.from_numpy(prev_obs_venv[key])
.float() .float()
.to(self.device) .to(self.device)
for key in self.obs_dims.keys() for key in self.obs_dims
} # batch each type of obs and put into dict }
samples = self.model( samples = self.model(
cond=cond, cond=cond,
deterministic=eval_mode, deterministic=eval_mode,
@ -107,7 +107,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
for k in obs_trajs.keys(): for k in obs_trajs:
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None])) obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
samples_trajs = np.vstack((samples_trajs, output_venv[None])) samples_trajs = np.vstack((samples_trajs, output_venv[None]))
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
@ -152,7 +152,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
with torch.no_grad(): with torch.no_grad():
# apply image randomization # apply image randomization
@ -180,7 +180,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
self.n_envs * self.n_steps / self.logprob_batch_size self.n_envs * self.n_steps / self.logprob_batch_size
) )
obs_ts = [{} for _ in range(num_split)] obs_ts = [{} for _ in range(num_split)]
for k in obs_trajs.keys(): for k in obs_trajs:
obs_k = einops.rearrange( obs_k = einops.rearrange(
obs_trajs[k], obs_trajs[k],
"s e ... -> (s e) ...", "s e ... -> (s e) ...",
@ -226,7 +226,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
# bootstrap value with GAE if not done - apply reward scaling with constant if specified # bootstrap value with GAE if not done - apply reward scaling with constant if specified
obs_venv_ts = { obs_venv_ts = {
key: torch.from_numpy(obs_venv[key]).float().to(self.device) key: torch.from_numpy(obs_venv[key]).float().to(self.device)
for key in self.obs_dims.keys() for key in self.obs_dims
} }
with torch.no_grad(): with torch.no_grad():
next_value = ( next_value = (
@ -266,7 +266,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
obs_trajs[k], obs_trajs[k],
"s e ... -> (s e) ...", "s e ... -> (s e) ...",
) )
for k in obs_trajs.keys() for k in obs_trajs
} }
samples_k = einops.rearrange( samples_k = einops.rearrange(
torch.tensor(samples_trajs).float().to(self.device), torch.tensor(samples_trajs).float().to(self.device),
@ -297,7 +297,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()} obs_b = {k: obs_k[k][inds_b] for k in obs_k}
samples_b = samples_k[inds_b] samples_b = samples_k[inds_b]
returns_b = returns_k[inds_b] returns_b = returns_k[inds_b]
values_b = values_k[inds_b] values_b = values_k[inds_b]
@ -383,6 +383,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -403,7 +404,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -437,7 +438,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
run_results[-1]["clip_frac"] = np.mean(clipfracs) run_results[-1]["clip_frac"] = np.mean(clipfracs)
run_results[-1]["explained_variance"] = explained_var run_results[-1]["explained_variance"] = explained_var
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -1,6 +1,8 @@
""" """
QSM (Q-Score Matching) for diffusion policy. QSM (Q-Score Matching) for diffusion policy.
Do not support pixel input right now.
""" """
import os import os
@ -75,8 +77,8 @@ class TrainQSMDiffusionAgent(TrainAgent):
# make a FIFO replay buffer for obs, action, and reward # make a FIFO replay buffer for obs, action, and reward
obs_buffer = deque(maxlen=self.buffer_size) obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
next_obs_buffer = deque(maxlen=self.buffer_size) next_obs_buffer = deque(maxlen=self.buffer_size)
action_buffer = deque(maxlen=self.buffer_size)
reward_buffer = deque(maxlen=self.buffer_size) reward_buffer = deque(maxlen=self.buffer_size)
done_buffer = deque(maxlen=self.buffer_size) done_buffer = deque(maxlen=self.buffer_size)
first_buffer = deque(maxlen=self.buffer_size) first_buffer = deque(maxlen=self.buffer_size)
@ -84,6 +86,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -98,9 +101,10 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -108,7 +112,6 @@ class TrainQSMDiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs)) reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
@ -118,11 +121,14 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode, deterministic=eval_mode,
) )
.cpu() .cpu()
@ -137,9 +143,9 @@ class TrainQSMDiffusionAgent(TrainAgent):
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
# add to buffer # add to buffer
obs_buffer.append(prev_obs_venv) obs_buffer.append(prev_obs_venv["state"])
next_obs_buffer.append(obs_venv["state"])
action_buffer.append(action_venv) action_buffer.append(action_venv)
next_obs_buffer.append(obs_venv)
reward_buffer.append(reward_venv * self.scale_reward_factor) reward_buffer.append(reward_venv * self.scale_reward_factor)
done_buffer.append(done_venv) done_buffer.append(done_venv)
first_buffer.append(firsts_trajs[step]) first_buffer.append(firsts_trajs[step])
@ -184,7 +190,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
obs_trajs = np.array(deepcopy(obs_buffer)) obs_trajs = np.array(deepcopy(obs_buffer))
@ -233,7 +239,12 @@ class TrainQSMDiffusionAgent(TrainAgent):
# update critic q function # update critic q function
critic_loss = self.model.loss_critic( critic_loss = self.model.loss_critic(
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma {"state": obs_b},
{"state": next_obs_b},
actions_b,
reward_b,
done_b,
self.gamma,
) )
self.critic_optimizer.zero_grad() self.critic_optimizer.zero_grad()
critic_loss.backward() critic_loss.backward()
@ -246,7 +257,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
# Update policy with collected trajectories # Update policy with collected trajectories
loss = self.model.loss_actor( loss = self.model.loss_actor(
obs_b, {"state": obs_b},
actions_b, actions_b,
self.q_grad_coeff, self.q_grad_coeff,
) )
@ -274,6 +285,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -294,7 +306,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -310,7 +322,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["loss_critic"] = loss_critic run_results[-1]["loss_critic"] = loss_critic
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -1,6 +1,8 @@
""" """
Reward-weighted regression (RWR) for diffusion policy. Reward-weighted regression (RWR) for diffusion policy.
Do not support pixel input right now.
""" """
import os import os
@ -54,6 +56,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Start training loop # Start training loop
timer = Timer() timer = Timer()
run_results = [] run_results = []
last_itr_eval = False
done_venv = np.zeros((1, self.n_envs)) done_venv = np.zeros((1, self.n_envs))
while self.itr < self.n_train_itr: while self.itr < self.n_train_itr:
@ -68,9 +71,10 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Define train or eval - all envs restart # Define train or eval - all envs restart
eval_mode = self.itr % self.val_freq == 0 and not self.force_train eval_mode = self.itr % self.val_freq == 0 and not self.force_train
self.model.eval() if eval_mode else self.model.train() self.model.eval() if eval_mode else self.model.train()
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs)) last_itr_eval = eval_mode
# Reset env at the beginning of an iteration # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
if self.reset_at_iteration or eval_mode or last_itr_eval: if self.reset_at_iteration or eval_mode or last_itr_eval:
prev_obs_venv = self.reset_env_all(options_venv=options_venv) prev_obs_venv = self.reset_env_all(options_venv=options_venv)
firsts_trajs[0] = 1 firsts_trajs[0] = 1
@ -78,11 +82,11 @@ class TrainRWRDiffusionAgent(TrainAgent):
firsts_trajs[0] = ( firsts_trajs[0] = (
done_venv # if done at the end of last iteration, then the envs are just reset done_venv # if done at the end of last iteration, then the envs are just reset
) )
last_itr_eval = eval_mode
reward_trajs = np.empty((0, self.n_envs))
# Holders # Holder
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim)) obs_trajs = {
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
}
samples_trajs = np.empty( samples_trajs = np.empty(
( (
0, 0,
@ -91,6 +95,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
self.action_dim, self.action_dim,
) )
) )
reward_trajs = np.empty((0, self.n_envs))
# Collect a set of trajectories from env # Collect a set of trajectories from env
for step in range(self.n_steps): for step in range(self.n_steps):
@ -99,24 +104,29 @@ class TrainRWRDiffusionAgent(TrainAgent):
# Select action # Select action
with torch.no_grad(): with torch.no_grad():
cond = {
"state": torch.from_numpy(prev_obs_venv["state"])
.float()
.to(self.device)
}
samples = ( samples = (
self.model( self.model(
cond=torch.from_numpy(prev_obs_venv) cond=cond,
.float()
.to(self.device),
deterministic=eval_mode, deterministic=eval_mode,
) )
.cpu() .cpu()
.numpy() .numpy()
) # n_env x horizon x act ) # n_env x horizon x act
action_venv = samples[:, : self.act_steps] action_venv = samples[:, : self.act_steps]
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
samples_trajs = np.vstack((samples_trajs, samples[None])) samples_trajs = np.vstack((samples_trajs, samples[None]))
# Apply multi-step action # Apply multi-step action
obs_venv, reward_venv, done_venv, info_venv = self.venv.step( obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
action_venv action_venv
) )
obs_trajs["state"] = np.vstack(
(obs_trajs["state"], prev_obs_venv["state"][None])
)
reward_trajs = np.vstack((reward_trajs, reward_venv[None])) reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
firsts_trajs[step + 1] = done_venv firsts_trajs[step + 1] = done_venv
prev_obs_venv = obs_venv prev_obs_venv = obs_venv
@ -133,7 +143,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
if len(episodes_start_end) > 0: if len(episodes_start_end) > 0:
# Compute transitions for completed trajectories # Compute transitions for completed trajectories
obs_trajs_split = [ obs_trajs_split = [
obs_trajs[start : end + 1, env_ind] {"state": obs_trajs["state"][start : end + 1, env_ind]}
for env_ind, start, end in episodes_start_end for env_ind, start, end in episodes_start_end
] ]
samples_trajs_split = [ samples_trajs_split = [
@ -183,17 +193,20 @@ class TrainRWRDiffusionAgent(TrainAgent):
success_rate = 0 success_rate = 0
log.info("[WARNING] No episode completed within the iteration!") log.info("[WARNING] No episode completed within the iteration!")
# Update # Update models
if not eval_mode: if not eval_mode:
# Tensorize data and put them to device # Tensorize data and put them to device
# k for environment step # k for environment step
obs_k = ( obs_k = {
torch.tensor(np.concatenate(obs_trajs_split)) "state": torch.tensor(
np.concatenate(
[obs_traj["state"] for obs_traj in obs_trajs_split]
)
)
.float() .float()
.to(self.device) .to(self.device)
) }
samples_k = ( samples_k = (
torch.tensor(np.concatenate(samples_trajs_split)) torch.tensor(np.concatenate(samples_trajs_split))
.float() .float()
@ -204,19 +217,15 @@ class TrainRWRDiffusionAgent(TrainAgent):
returns_trajs_split = ( returns_trajs_split = (
returns_trajs_split - np.mean(returns_trajs_split) returns_trajs_split - np.mean(returns_trajs_split)
) / (returns_trajs_split.std() + 1e-3) ) / (returns_trajs_split.std() + 1e-3)
rewards_k = ( rewards_k = (
torch.tensor(returns_trajs_split) torch.tensor(returns_trajs_split)
.float() .float()
.to(self.device) .to(self.device)
.reshape(-1) .reshape(-1)
) )
rewards_k_scaled = torch.exp(self.beta * rewards_k) rewards_k_scaled = torch.exp(self.beta * rewards_k)
rewards_k_scaled.clamp_(max=self.max_reward_weight) rewards_k_scaled.clamp_(max=self.max_reward_weight)
# rewards_k_scaled = rewards_k_scaled / rewards_k_scaled.mean()
# Update policy and critic # Update policy and critic
total_steps = len(rewards_k_scaled) total_steps = len(rewards_k_scaled)
inds_k = np.arange(total_steps) inds_k = np.arange(total_steps)
@ -229,7 +238,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
start = batch * self.batch_size start = batch * self.batch_size
end = start + self.batch_size end = start + self.batch_size
inds_b = inds_k[start:end] # b for batch inds_b = inds_k[start:end] # b for batch
obs_b = obs_k[inds_b] obs_b = {"state": obs_k["state"][inds_b]}
samples_b = samples_k[inds_b] samples_b = samples_k[inds_b]
rewards_b = rewards_k_scaled[inds_b] rewards_b = rewards_k_scaled[inds_b]
@ -261,6 +270,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
} }
) )
if self.itr % self.log_freq == 0: if self.itr % self.log_freq == 0:
time = timer()
if eval_mode: if eval_mode:
log.info( log.info(
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}" f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
@ -281,7 +291,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
run_results[-1]["eval_best_reward"] = avg_best_reward run_results[-1]["eval_best_reward"] = avg_best_reward
else: else:
log.info( log.info(
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}" f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
) )
if self.use_wandb: if self.use_wandb:
wandb.log( wandb.log(
@ -295,7 +305,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
) )
run_results[-1]["loss"] = loss run_results[-1]["loss"] = loss
run_results[-1]["train_episode_reward"] = avg_episode_reward run_results[-1]["train_episode_reward"] = avg_episode_reward
run_results[-1]["time"] = timer() run_results[-1]["time"] = time
with open(self.result_path, "wb") as f: with open(self.result_path, "wb") as f:
pickle.dump(run_results, f) pickle.dump(run_results, f)
self.itr += 1 self.itr += 1

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -105,15 +105,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -99,7 +99,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -100,7 +100,6 @@ model:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
residual_style: True residual_style: True
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False residual_style: False
fixed_std: 0.1 fixed_std: 0.1
num_modes: ${num_modes} num_modes: ${num_modes}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False residual_style: False
fixed_std: 0.1 fixed_std: 0.1
num_modes: ${num_modes} num_modes: ${num_modes}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -54,9 +54,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -47,7 +47,7 @@ model:
residual_style: False residual_style: False
fixed_std: 0.1 fixed_std: 0.1
num_modes: ${num_modes} num_modes: ${num_modes}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -109,15 +109,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -100,10 +100,9 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -107,15 +107,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,15 +95,14 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -58,9 +58,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -58,9 +58,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -56,9 +56,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -57,9 +57,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
critic_v: critic_v:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -99,20 +99,18 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
critic_v: critic_v:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -98,20 +98,18 @@ model:
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -83,20 +83,19 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -81,7 +81,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -90,13 +90,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -84,7 +84,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -93,19 +93,18 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
critic_v: critic_v:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -91,20 +91,18 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -82,7 +82,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -91,13 +91,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -74,7 +74,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -82,6 +82,5 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -87,22 +87,20 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -80,15 +80,14 @@ model:
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
residual_style: True residual_style: True
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -45,7 +45,7 @@ model:
_target_: model.diffusion.mlp_diffusion.DiffusionMLP _target_: model.diffusion.mlp_diffusion.DiffusionMLP
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
cond_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
time_dim: 16 time_dim: 16
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
activation_type: ReLU activation_type: ReLU
@ -55,9 +55,7 @@ model:
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}
ema: ema:

View File

@ -94,13 +94,12 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -95,13 +95,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -94,13 +94,12 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -97,19 +97,18 @@ model:
_target_: model.common.critic.CriticObsAct _target_: model.common.critic.CriticObsAct
action_dim: ${action_dim} action_dim: ${action_dim}
action_steps: ${act_steps} action_steps: ${act_steps}
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
critic_v: critic_v:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
transition_dim: ${transition_dim}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -100,15 +100,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -20,6 +20,7 @@ transition_dim: ${action_dim}
denoising_steps: 100 denoising_steps: 100
ft_denoising_steps: 5 ft_denoising_steps: 5
cond_steps: 1 cond_steps: 1
img_cond_steps: 1
horizon_steps: 4 horizon_steps: 4
act_steps: 4 act_steps: 4
use_ddim: True use_ddim: True
@ -121,6 +122,7 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -133,6 +135,7 @@ model:
time_dim: 32 time_dim: 32
mlp_dims: [512, 512, 512] mlp_dims: [512, 512, 512]
residual_style: True residual_style: True
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
@ -143,6 +146,7 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -150,15 +154,14 @@ model:
num_heads: 4 num_heads: 4
embed_style: embed2 embed_style: embed2
embed_norm: 0 embed_norm: 0
obs_dim: ${obs_dim} img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -103,15 +103,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -108,15 +108,13 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
ft_denoising_steps: ${ft_denoising_steps} ft_denoising_steps: ${ft_denoising_steps}
transition_dim: ${transition_dim}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
obs_dim: ${obs_dim} obs_dim: ${obs_dim}
action_dim: ${action_dim} action_dim: ${action_dim}
cond_steps: ${cond_steps}
denoising_steps: ${denoising_steps} denoising_steps: ${denoising_steps}
device: ${device} device: ${device}

View File

@ -94,10 +94,9 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -18,6 +18,7 @@ obs_dim: 9
action_dim: 7 action_dim: 7
transition_dim: ${action_dim} transition_dim: ${action_dim}
cond_steps: 1 cond_steps: 1
img_cond_steps: 1
horizon_steps: 4 horizon_steps: 4
act_steps: 4 act_steps: 4
@ -100,6 +101,7 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -115,6 +117,7 @@ model:
learn_fixed_std: True learn_fixed_std: True
std_min: 0.01 std_min: 0.01
std_max: 0.2 std_max: 0.2
img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'} cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
@ -125,6 +128,7 @@ model:
backbone: backbone:
_target_: model.common.vit.VitEncoder _target_: model.common.vit.VitEncoder
obs_shape: ${shape_meta.obs.rgb.shape} obs_shape: ${shape_meta.obs.rgb.shape}
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
cfg: cfg:
patch_size: 8 patch_size: 8
depth: 1 depth: 1
@ -132,10 +136,10 @@ model:
num_heads: 4 num_heads: 4
embed_style: embed2 embed_style: embed2
embed_norm: 0 embed_norm: 0
obs_dim: ${obs_dim} img_cond_steps: ${img_cond_steps}
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

View File

@ -26,7 +26,7 @@ env:
name: ${env_name} name: ${env_name}
best_reward_threshold_for_success: 1 best_reward_threshold_for_success: 1
max_episode_steps: 300 max_episode_steps: 300
save_video: false save_video: False
wrappers: wrappers:
robomimic_lowdim: robomimic_lowdim:
normalization_path: ${normalization_path} normalization_path: ${normalization_path}
@ -95,10 +95,9 @@ model:
transition_dim: ${transition_dim} transition_dim: ${transition_dim}
critic: critic:
_target_: model.common.critic.CriticObs _target_: model.common.critic.CriticObs
obs_dim: ${obs_dim}
mlp_dims: [256, 256, 256] mlp_dims: [256, 256, 256]
activation_type: Mish activation_type: Mish
residual_style: True residual_style: True
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
horizon_steps: ${horizon_steps} horizon_steps: ${horizon_steps}
cond_steps: ${cond_steps}
device: ${device} device: ${device}

Some files were not shown because too many files have changed in this diff Show More