squash commits
This commit is contained in:
parent
8ce0aa1485
commit
2ddf63b8f5
@ -16,19 +16,21 @@ import random
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
Batch = namedtuple("Batch", "trajectories conditions")
|
||||
Batch = namedtuple("Batch", "actions conditions")
|
||||
|
||||
|
||||
class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
Dataset to efficiently load and sample trajectories. Stitches episodes together in the time dimension to avoid excessive zero padding. Episode ID's are used to index unique trajectories.
|
||||
Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file.
|
||||
|
||||
Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is traj length of episode e and D is
|
||||
(tuple of) dimension of observation, action, images, etc.
|
||||
Use the first max_n_episodes episodes (instead of random sampling)
|
||||
|
||||
Example:
|
||||
states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
|
||||
Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
|
||||
Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
|
||||
|
||||
Each sample is a namedtuple of (1) chunked actions and (2) a list (obs timesteps) of dictionary with keys states and images.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -36,23 +38,30 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||
dataset_path,
|
||||
horizon_steps=64,
|
||||
cond_steps=1,
|
||||
img_cond_steps=1,
|
||||
max_n_episodes=10000,
|
||||
use_img=False,
|
||||
device="cuda:0",
|
||||
):
|
||||
assert (
|
||||
img_cond_steps <= cond_steps
|
||||
), "consider using more cond_steps than img_cond_steps"
|
||||
self.horizon_steps = horizon_steps
|
||||
self.cond_steps = cond_steps
|
||||
self.cond_steps = cond_steps # states (proprio, etc.)
|
||||
self.img_cond_steps = img_cond_steps
|
||||
self.device = device
|
||||
self.use_img = use_img
|
||||
|
||||
# Load dataset to device specified
|
||||
if dataset_path.endswith(".npz"):
|
||||
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
|
||||
else:
|
||||
elif dataset_path.endswith(".pkl"):
|
||||
with open(dataset_path, "rb") as f:
|
||||
dataset = pickle.load(f)
|
||||
traj_lengths = dataset["traj_lengths"] # 1-D array
|
||||
total_num_steps = np.sum(traj_lengths[:max_n_episodes])
|
||||
else:
|
||||
raise ValueError(f"Unsupported file format: {dataset_path}")
|
||||
traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array
|
||||
total_num_steps = np.sum(traj_lengths)
|
||||
|
||||
# Set up indices for sampling
|
||||
self.indices = self.make_indices(traj_lengths, horizon_steps)
|
||||
@ -75,35 +84,51 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
|
||||
|
||||
def __getitem__(self, idx):
    """
    Return one sample: a chunk of actions plus the observation history that conditions it.

    self.indices[idx] is a (start, num_before_start) tuple, where start is the position of the
    current step in the stitched arrays and num_before_start is how many steps of the same
    trajectory precede it. The history holds cond_steps states (and img_cond_steps images when
    use_img is set), ordered oldest to most recent. Near the beginning of an episode, where
    fewer than cond_steps earlier steps exist, the first step of the trajectory is repeated.

    Returns a Batch namedtuple of (actions, conditions); conditions is a dict with key "state"
    and, when use_img is set, key "rgb".
    NOTE(review): assumes self.states / self.actions / self.images are torch tensors indexed by
    time along dim 0 (torch.stack is applied to their rows) - confirm against __init__.
    """
    start, num_before_start = self.indices[idx]
    end = start + self.horizon_steps
    # Slice from the first step of this trajectory, so that position num_before_start in the
    # slice is the current step and position 0 is the first step of the episode.
    states = self.states[(start - num_before_start) : end]
    actions = self.actions[start:end]
    # Clamp with max (not min): num_before_start - t walks back t steps from the current step
    # and must stop at 0. With min the index was never positive, so it always picked the first
    # step of the trajectory or wrapped around to the end of the slice via a negative index.
    states = torch.stack(
        [
            states[max(num_before_start - t, 0)]
            for t in reversed(range(self.cond_steps))
        ]
    )  # more recent is at the end
    conditions = {"state": states}
    if self.use_img:
        images = self.images[(start - num_before_start) : end]
        images = torch.stack(
            [
                images[max(num_before_start - t, 0)]
                for t in reversed(range(self.img_cond_steps))
            ]
        )
        conditions["rgb"] = images
    batch = Batch(actions, conditions)
    return batch
|
||||
|
||||
def make_indices(self, traj_lengths, horizon_steps):
    """
    Build the list of sampling indices for the stitched dataset.

    Each entry is a tuple (start, num_before_start): start is the position in the stitched
    (concatenated) arrays where a chunk of horizon_steps steps begins, and num_before_start is
    the number of steps that precede it within the same trajectory. Only starts whose whole
    chunk fits inside the trajectory are kept, so a trajectory shorter than horizon_steps
    contributes no entries.
    """
    indices = []
    cur_traj_index = 0  # offset of the current trajectory's first step in the stitched arrays
    for traj_length in traj_lengths:
        # Exclusive upper bound on start: the last valid chunk ends exactly at the trajectory end
        max_start = cur_traj_index + traj_length - horizon_steps + 1
        indices += [
            (i, i - cur_traj_index) for i in range(cur_traj_index, max_start)
        ]
        cur_traj_index += traj_length
    return indices
|
||||
|
||||
def set_train_val_split(self, train_split):
|
||||
"""Not doing validation right now"""
|
||||
"""
|
||||
Not doing validation right now
|
||||
"""
|
||||
num_train = int(len(self.indices) * train_split)
|
||||
train_indices = random.sample(self.indices, num_train)
|
||||
val_indices = [i for i in range(len(self.indices)) if i not in train_indices]
|
||||
|
@ -3,6 +3,8 @@ Advantage-weighted regression (AWR) for diffusion policy.
|
||||
|
||||
Advantage = discounted-reward-to-go - V(s)
|
||||
|
||||
Pixel input is not supported right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -131,6 +133,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -145,9 +148,10 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -155,7 +159,6 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
@ -165,16 +168,19 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
.cpu()
|
||||
.numpy()
|
||||
) # n_env x horizon x act
|
||||
)
|
||||
action_venv = samples[:, : self.act_steps]
|
||||
|
||||
# Apply multi-step action
|
||||
@ -184,7 +190,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
|
||||
# add to buffer
|
||||
obs_buffer.append(prev_obs_venv)
|
||||
obs_buffer.append(prev_obs_venv["state"])
|
||||
action_buffer.append(action_venv)
|
||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||
done_buffer.append(done_venv)
|
||||
@ -230,59 +236,46 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||
obs_trajs = np.array(deepcopy(obs_buffer)) # assume only state
|
||||
reward_trajs = np.array(deepcopy(reward_buffer))
|
||||
dones_trajs = np.array(deepcopy(done_buffer))
|
||||
|
||||
obs_t = einops.rearrange(
|
||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
)
|
||||
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
|
||||
values_trajs = values_t.reshape(-1, self.n_envs)
|
||||
values_trajs = np.array(
|
||||
self.model.critic({"state": obs_t}).detach().cpu().numpy()
|
||||
).reshape(-1, self.n_envs)
|
||||
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
||||
td_t = torch.from_numpy(td_trajs.flatten()).float().to(self.device)
|
||||
|
||||
# flatten
|
||||
obs_trajs = einops.rearrange(
|
||||
obs_trajs,
|
||||
"s e h d -> (s e) h d",
|
||||
)
|
||||
td_trajs = einops.rearrange(
|
||||
td_trajs,
|
||||
"s e -> (s e)",
|
||||
)
|
||||
|
||||
# Update policy and critic
|
||||
# Update critic
|
||||
num_batch = int(
|
||||
self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
|
||||
)
|
||||
for _ in range(num_batch // self.critic_update_ratio):
|
||||
|
||||
# Sample batch
|
||||
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
||||
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
|
||||
td_b = torch.from_numpy(td_trajs[inds]).float().to(self.device)
|
||||
|
||||
# Update critic
|
||||
loss_critic = self.model.loss_critic(obs_b, td_b)
|
||||
loss_critic = self.model.loss_critic(
|
||||
{"state": obs_t[inds]}, td_t[inds]
|
||||
)
|
||||
self.critic_optimizer.zero_grad()
|
||||
loss_critic.backward()
|
||||
self.critic_optimizer.step()
|
||||
|
||||
# Update policy - use a new copy of data
|
||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||
samples_trajs = np.array(deepcopy(action_buffer))
|
||||
reward_trajs = np.array(deepcopy(reward_buffer))
|
||||
dones_trajs = np.array(deepcopy(done_buffer))
|
||||
|
||||
obs_t = einops.rearrange(
|
||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
)
|
||||
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
|
||||
values_trajs = values_t.reshape(-1, self.n_envs)
|
||||
values_trajs = np.array(
|
||||
self.model.critic({"state": obs_t}).detach().cpu().numpy()
|
||||
).reshape(-1, self.n_envs)
|
||||
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
||||
advantages_trajs = td_trajs - values_trajs
|
||||
|
||||
@ -304,7 +297,11 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
|
||||
# Sample batch
|
||||
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
||||
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
|
||||
obs_b = {
|
||||
"state": torch.from_numpy(obs_trajs[inds])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
actions_b = (
|
||||
torch.from_numpy(samples_trajs[inds]).float().to(self.device)
|
||||
)
|
||||
@ -347,6 +344,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -367,7 +365,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -383,7 +381,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["loss_critic"] = loss_critic
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -4,6 +4,9 @@ Model-free online RL with DIffusion POlicy (DIPO)
|
||||
Applies action gradient to perturb actions towards maximizer of Q-function.
|
||||
|
||||
a_t <- a_t + \eta * \grad_a Q(s, a)
|
||||
|
||||
Pixel input is not supported right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -90,6 +93,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -104,9 +108,10 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -114,7 +119,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
@ -124,11 +128,14 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
.cpu()
|
||||
@ -144,8 +151,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
|
||||
# add to buffer
|
||||
for i in range(self.n_envs):
|
||||
obs_buffer.append(prev_obs_venv[i])
|
||||
next_obs_buffer.append(obs_venv[i])
|
||||
obs_buffer.append(prev_obs_venv["state"][i])
|
||||
next_obs_buffer.append(obs_venv["state"][i])
|
||||
action_buffer.append(action_venv[i])
|
||||
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
||||
done_buffer.append(done_venv[i])
|
||||
@ -191,8 +198,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
num_batch = self.replay_ratio
|
||||
|
||||
# Critic learning
|
||||
@ -231,7 +238,12 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
|
||||
# Update critic
|
||||
loss_critic = self.model.loss_critic(
|
||||
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
|
||||
{"state": obs_b},
|
||||
{"state": next_obs_b},
|
||||
actions_b,
|
||||
rewards_b,
|
||||
dones_b,
|
||||
self.gamma,
|
||||
)
|
||||
self.critic_optimizer.zero_grad()
|
||||
loss_critic.backward()
|
||||
@ -239,7 +251,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
|
||||
# Actor learning
|
||||
for _ in range(num_batch):
|
||||
|
||||
# Sample batch
|
||||
inds = np.random.choice(len(obs_buffer), self.batch_size)
|
||||
obs_b = (
|
||||
@ -265,7 +276,9 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
)
|
||||
for _ in range(self.action_gradient_steps):
|
||||
actions_flat.requires_grad_(True)
|
||||
q_values_1, q_values_2 = self.model.critic(obs_b, actions_flat)
|
||||
q_values_1, q_values_2 = self.model.critic(
|
||||
{"state": obs_b}, actions_flat
|
||||
)
|
||||
q_values = torch.min(q_values_1, q_values_2)
|
||||
action_opt_loss = -q_values.sum()
|
||||
|
||||
@ -291,7 +304,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
)
|
||||
|
||||
# Update policy with collected trajectories
|
||||
loss = self.model.loss(guided_action.detach(), {0: obs_b})
|
||||
loss = self.model.loss(guided_action.detach(), {"state": obs_b})
|
||||
self.actor_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
if self.itr >= self.n_critic_warmup_itr:
|
||||
@ -316,6 +329,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -336,7 +350,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -352,7 +366,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["loss_critic"] = loss_critic
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -5,6 +5,9 @@ Learns a critic Q-function and backprops the expected Q-value to train the actor
|
||||
|
||||
pi = argmin L_d(\theta) - \alpha * E[Q(s, a)]
|
||||
L_d is demonstration loss for regularization
|
||||
|
||||
Pixel input is not supported right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -38,7 +41,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
lr=cfg.train.actor_lr,
|
||||
weight_decay=cfg.train.actor_weight_decay,
|
||||
)
|
||||
# use cosine scheduler with linear warmup
|
||||
self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
|
||||
self.actor_optimizer,
|
||||
first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
|
||||
@ -88,6 +90,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -102,9 +105,10 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -112,7 +116,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
@ -122,11 +125,14 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
.cpu()
|
||||
@ -142,8 +148,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# add to buffer
|
||||
for i in range(self.n_envs):
|
||||
obs_buffer.append(prev_obs_venv[i])
|
||||
next_obs_buffer.append(obs_venv[i])
|
||||
obs_buffer.append(prev_obs_venv["state"][i])
|
||||
next_obs_buffer.append(obs_venv["state"][i])
|
||||
action_buffer.append(action_venv[i])
|
||||
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
||||
done_buffer.append(done_venv[i])
|
||||
@ -189,8 +195,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
num_batch = self.replay_ratio
|
||||
|
||||
# Critic learning
|
||||
@ -229,7 +235,12 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# Update critic
|
||||
loss_critic = self.model.loss_critic(
|
||||
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
|
||||
{"state": obs_b},
|
||||
{"state": next_obs_b},
|
||||
actions_b,
|
||||
rewards_b,
|
||||
dones_b,
|
||||
self.gamma,
|
||||
)
|
||||
self.critic_optimizer.zero_grad()
|
||||
loss_critic.backward()
|
||||
@ -237,19 +248,21 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# get the new action and q values
|
||||
samples = self.model.forward_train(
|
||||
cond=obs_b.to(self.device),
|
||||
cond={"state": obs_b},
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
output_venv = samples # n_env x horizon x act
|
||||
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
|
||||
actions_flat_b = action_venv.reshape(action_venv.shape[0], -1)
|
||||
q_values_b = self.model.critic(obs_b, actions_flat_b)
|
||||
action_venv = samples[:, : self.act_steps] # n_env x horizon x act
|
||||
q_values_b = self.model.critic({"state": obs_b}, action_venv)
|
||||
q1_new_action, q2_new_action = q_values_b
|
||||
|
||||
# Update policy with collected trajectories
|
||||
self.actor_optimizer.zero_grad()
|
||||
actor_loss = self.model.loss_actor(
|
||||
obs_b, actions_b, q1_new_action, q2_new_action, self.eta
|
||||
{"state": obs_b},
|
||||
actions_b,
|
||||
q1_new_action,
|
||||
q2_new_action,
|
||||
self.eta,
|
||||
)
|
||||
actor_loss.backward()
|
||||
if self.itr >= self.n_critic_warmup_itr:
|
||||
@ -275,6 +288,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -295,7 +309,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -311,7 +325,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["loss_critic"] = loss_critic
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
|
||||
|
||||
Pixel input is not supported right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -11,7 +13,6 @@ import torch
|
||||
import logging
|
||||
import wandb
|
||||
from copy import deepcopy
|
||||
import random
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from util.timer import Timer
|
||||
@ -98,8 +99,8 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# make a FIFO replay buffer for obs, action, and reward
|
||||
obs_buffer = deque(maxlen=self.buffer_size)
|
||||
action_buffer = deque(maxlen=self.buffer_size)
|
||||
next_obs_buffer = deque(maxlen=self.buffer_size)
|
||||
action_buffer = deque(maxlen=self.buffer_size)
|
||||
reward_buffer = deque(maxlen=self.buffer_size)
|
||||
done_buffer = deque(maxlen=self.buffer_size)
|
||||
first_buffer = deque(maxlen=self.buffer_size)
|
||||
@ -107,6 +108,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -121,9 +123,10 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -131,7 +134,6 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
@ -141,11 +143,14 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode and self.eval_deterministic,
|
||||
num_sample=self.num_sample,
|
||||
use_expectile_exploration=self.use_expectile_exploration,
|
||||
@ -162,9 +167,9 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
|
||||
# add to buffer
|
||||
obs_buffer.append(prev_obs_venv)
|
||||
obs_buffer.append(prev_obs_venv["state"])
|
||||
next_obs_buffer.append(obs_venv["state"])
|
||||
action_buffer.append(action_venv)
|
||||
next_obs_buffer.append(obs_venv)
|
||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||
done_buffer.append(done_venv)
|
||||
first_buffer.append(firsts_trajs[step])
|
||||
@ -209,7 +214,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||
@ -257,14 +262,21 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device)
|
||||
|
||||
# update critic value function
|
||||
critic_loss_v = self.model.loss_critic_v(obs_b, actions_b)
|
||||
critic_loss_v = self.model.loss_critic_v(
|
||||
{"state": obs_b}, actions_b
|
||||
)
|
||||
self.critic_v_optimizer.zero_grad()
|
||||
critic_loss_v.backward()
|
||||
self.critic_v_optimizer.step()
|
||||
|
||||
# update critic q function
|
||||
critic_loss_q = self.model.loss_critic_q(
|
||||
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
|
||||
{"state": obs_b},
|
||||
{"state": next_obs_b},
|
||||
actions_b,
|
||||
reward_b,
|
||||
done_b,
|
||||
self.gamma,
|
||||
)
|
||||
self.critic_q_optimizer.zero_grad()
|
||||
critic_loss_q.backward()
|
||||
@ -278,7 +290,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
# Update policy with collected trajectories - no weighting
|
||||
loss = self.model.loss(
|
||||
actions_b,
|
||||
obs_b,
|
||||
{"state": obs_b},
|
||||
)
|
||||
self.actor_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
@ -305,6 +317,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -325,7 +338,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -341,7 +354,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["loss_critic"] = loss_critic
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import wandb
|
||||
import math
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from util.timer import Timer
|
||||
@ -78,7 +79,9 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
)
|
||||
|
||||
# Holder
|
||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
obs_trajs = {
|
||||
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
}
|
||||
chains_trajs = np.empty(
|
||||
(
|
||||
0,
|
||||
@ -91,8 +94,8 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
||||
) # remove cond_step dim
|
||||
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||
) # save current obs
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
for step in range(self.n_steps):
|
||||
@ -101,8 +104,13 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
return_chain=True,
|
||||
)
|
||||
@ -118,14 +126,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
if self.save_full_observations:
|
||||
obs_full_venv = np.vstack(
|
||||
[info["full_obs"][None] for info in info_venv]
|
||||
) # n_envs x n_act_steps x obs_dim
|
||||
if self.save_full_observations: # state-only
|
||||
obs_full_venv = np.array(
|
||||
[info["full_obs"]["state"] for info in info_venv]
|
||||
) # n_envs x act_steps x obs_dim
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||
)
|
||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
||||
obs_trajs["state"] = np.vstack(
|
||||
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||
)
|
||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||
@ -177,12 +187,22 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
with torch.no_grad():
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
obs_t = einops.rearrange(
|
||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_trajs["state"] = (
|
||||
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||
)
|
||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
||||
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
num_split = math.ceil(
|
||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||
)
|
||||
obs_ts = [{} for _ in range(num_split)]
|
||||
obs_k = einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||
for i, obs_t in enumerate(obs_ts_k):
|
||||
obs_ts[i]["state"] = obs_t
|
||||
values_trajs = np.empty((0, self.n_envs))
|
||||
for obs in obs_ts:
|
||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||
@ -219,7 +239,11 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
reward_trajs = reward_trajs_transpose.T
|
||||
|
||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
||||
obs_venv_ts = {
|
||||
"state": torch.from_numpy(obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
@ -250,10 +274,12 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
returns_trajs = advantages_trajs + values_trajs
|
||||
|
||||
# k for environment step
|
||||
obs_k = einops.rearrange(
|
||||
torch.tensor(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_k = {
|
||||
"state": einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
}
|
||||
chains_k = einops.rearrange(
|
||||
torch.tensor(chains_trajs).float().to(self.device),
|
||||
"s e t h d -> (s e) t h d",
|
||||
@ -283,7 +309,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = obs_k[inds_b]
|
||||
obs_b = {"state": obs_k["state"][inds_b]}
|
||||
chains_b = chains_k[inds_b]
|
||||
returns_b = returns_k[inds_b]
|
||||
values_b = values_k[inds_b]
|
||||
@ -351,7 +377,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
)
|
||||
|
||||
# Plot state trajectories in D3IL
|
||||
# Plot state trajectories (only in D3IL)
|
||||
if (
|
||||
self.itr % self.render_freq == 0
|
||||
and self.n_render > 0
|
||||
|
@ -30,7 +30,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
|
||||
# Set obs dim - we will save the different obs in batch in a dict
|
||||
shape_meta = cfg.shape_meta
|
||||
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()}
|
||||
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs}
|
||||
|
||||
# Gradient accumulation to deal with large GPU RAM usage
|
||||
self.grad_accumulate = cfg.train.grad_accumulate
|
||||
@ -95,7 +95,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
key: torch.from_numpy(prev_obs_venv[key])
|
||||
.float()
|
||||
.to(self.device)
|
||||
for key in self.obs_dims.keys()
|
||||
for key in self.obs_dims
|
||||
} # batch each type of obs and put into dict
|
||||
samples = self.model(
|
||||
cond=cond,
|
||||
@ -114,7 +114,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
for k in obs_trajs.keys():
|
||||
for k in obs_trajs:
|
||||
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
@ -159,7 +159,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
with torch.no_grad():
|
||||
# apply image randomization
|
||||
@ -187,7 +187,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||
)
|
||||
obs_ts = [{} for _ in range(num_split)]
|
||||
for k in obs_trajs.keys():
|
||||
for k in obs_trajs:
|
||||
obs_k = einops.rearrange(
|
||||
obs_trajs[k],
|
||||
"s e ... -> (s e) ...",
|
||||
@ -238,7 +238,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||
obs_venv_ts = {
|
||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||
for key in self.obs_dims.keys()
|
||||
for key in self.obs_dims
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
@ -278,7 +278,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
obs_trajs[k],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
for k in obs_trajs.keys()
|
||||
for k in obs_trajs
|
||||
}
|
||||
chains_k = einops.rearrange(
|
||||
torch.tensor(chains_trajs).float().to(self.device),
|
||||
@ -309,7 +309,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
|
||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
|
||||
chains_b = chains_k[inds_b]
|
||||
returns_b = returns_k[inds_b]
|
||||
values_b = values_k[inds_b]
|
||||
@ -387,7 +387,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
)
|
||||
|
||||
# Update lr
|
||||
# Update lr, min_sampling_std
|
||||
if self.itr >= self.n_critic_warmup_itr:
|
||||
self.actor_lr_scheduler.step()
|
||||
if self.learn_eta:
|
||||
@ -407,6 +407,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -427,7 +428,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -462,7 +463,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||
run_results[-1]["explained_variance"] = explained_var
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
Use diffusion exact likelihood for policy gradient.
|
||||
|
||||
Does not support pixel input yet.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -10,6 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import wandb
|
||||
import math
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from util.timer import Timer
|
||||
@ -43,6 +46,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
eval_mode = False
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
@ -58,7 +62,9 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
)
|
||||
|
||||
# Holder
|
||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
obs_trajs = {
|
||||
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
}
|
||||
samples_trajs = np.empty(
|
||||
(
|
||||
0,
|
||||
@ -79,8 +85,8 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
||||
) # remove cond_step dim
|
||||
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||
) # save current obs
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
for step in range(self.n_steps):
|
||||
@ -89,8 +95,13 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
return_chain=True,
|
||||
)
|
||||
@ -101,21 +112,23 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
samples.chains.cpu().numpy()
|
||||
) # n_env x denoising x horizon x act
|
||||
action_venv = output_venv[:, : self.act_steps]
|
||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||
|
||||
# Apply multi-step action
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
if self.save_full_observations:
|
||||
obs_full_venv = np.vstack(
|
||||
[info["full_obs"][None] for info in info_venv]
|
||||
) # n_envs x n_act_steps x obs_dim
|
||||
if self.save_full_observations: # state-only
|
||||
obs_full_venv = np.array(
|
||||
[info["full_obs"]["state"] for info in info_venv]
|
||||
) # n_envs x act_steps x obs_dim
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||
)
|
||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
||||
obs_trajs["state"] = np.vstack(
|
||||
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||
)
|
||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||
firsts_trajs[step + 1] = done_venv
|
||||
@ -158,15 +171,25 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
with torch.no_grad():
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
obs_t = einops.rearrange(
|
||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_trajs["state"] = (
|
||||
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||
)
|
||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
||||
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
num_split = math.ceil(
|
||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||
)
|
||||
obs_ts = [{} for _ in range(num_split)]
|
||||
obs_k = einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||
for i, obs_t in enumerate(obs_ts_k):
|
||||
obs_ts[i]["state"] = obs_t
|
||||
values_trajs = np.empty((0, self.n_envs))
|
||||
for obs in obs_ts:
|
||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||
@ -193,7 +216,11 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
reward_trajs = reward_trajs_transpose.T
|
||||
|
||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
||||
obs_venv_ts = {
|
||||
"state": torch.from_numpy(obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
@ -224,10 +251,12 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
returns_trajs = advantages_trajs + values_trajs
|
||||
|
||||
# k for environment step
|
||||
obs_k = einops.rearrange(
|
||||
torch.tensor(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_k = {
|
||||
"state": einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
}
|
||||
samples_k = einops.rearrange(
|
||||
torch.tensor(samples_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
@ -257,7 +286,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = obs_k[inds_b]
|
||||
obs_b = {"state": obs_k["state"][inds_b]}
|
||||
samples_b = samples_k[inds_b]
|
||||
returns_b = returns_k[inds_b]
|
||||
values_b = values_k[inds_b]
|
||||
@ -318,7 +347,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
)
|
||||
|
||||
# Plot state trajectories
|
||||
# Plot state trajectories (only in D3IL)
|
||||
if (
|
||||
self.itr % self.render_freq == 0
|
||||
and self.n_render > 0
|
||||
@ -354,6 +383,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
run_results[-1]["chains_trajs"] = chains_trajs
|
||||
run_results[-1]["reward_trajs"] = reward_trajs
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -374,7 +404,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -399,7 +429,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||
run_results[-1]["explained_variance"] = explained_var
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
import wandb
|
||||
import math
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from util.timer import Timer
|
||||
@ -55,7 +56,9 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
)
|
||||
|
||||
# Holder
|
||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
obs_trajs = {
|
||||
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
}
|
||||
samples_trajs = np.empty(
|
||||
(
|
||||
0,
|
||||
@ -67,8 +70,8 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
||||
) # remove cond_step dim
|
||||
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||
) # save current obs
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
for step in range(self.n_steps):
|
||||
@ -77,26 +80,33 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
output_venv = samples.cpu().numpy()
|
||||
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
|
||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||
action_venv = output_venv[:, : self.act_steps]
|
||||
|
||||
# Apply multi-step action
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
if self.save_full_observations:
|
||||
obs_full_venv = np.vstack(
|
||||
[info["full_obs"][None] for info in info_venv]
|
||||
) # n_envs x n_act_steps x obs_dim
|
||||
if self.save_full_observations: # state-only
|
||||
obs_full_venv = np.array(
|
||||
[info["full_obs"]["state"] for info in info_venv]
|
||||
) # n_envs x act_steps x obs_dim
|
||||
obs_full_trajs = np.vstack(
|
||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||
)
|
||||
obs_trajs["state"] = np.vstack(
|
||||
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||
)
|
||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||
firsts_trajs[step + 1] = done_venv
|
||||
@ -144,15 +154,25 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
with torch.no_grad():
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
obs_t = einops.rearrange(
|
||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_trajs["state"] = (
|
||||
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||
)
|
||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
||||
|
||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||
num_split = math.ceil(
|
||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||
)
|
||||
obs_ts = [{} for _ in range(num_split)]
|
||||
obs_k = einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||
for i, obs_t in enumerate(obs_ts_k):
|
||||
obs_ts[i]["state"] = obs_t
|
||||
values_trajs = np.empty((0, self.n_envs))
|
||||
for obs in obs_ts:
|
||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||
@ -184,7 +204,11 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
reward_trajs = reward_trajs_transpose.T
|
||||
|
||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
||||
obs_venv_ts = {
|
||||
"state": torch.from_numpy(obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||
@ -215,10 +239,12 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
returns_trajs = advantages_trajs + values_trajs
|
||||
|
||||
# k for environment step
|
||||
obs_k = einops.rearrange(
|
||||
torch.tensor(obs_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
obs_k = {
|
||||
"state": einops.rearrange(
|
||||
obs_trajs["state"],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
}
|
||||
samples_k = einops.rearrange(
|
||||
torch.tensor(samples_trajs).float().to(self.device),
|
||||
"s e h d -> (s e) h d",
|
||||
@ -250,7 +276,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = obs_k[inds_b]
|
||||
obs_b = {"state": obs_k["state"][inds_b]}
|
||||
samples_b = samples_k[inds_b]
|
||||
returns_b = returns_k[inds_b]
|
||||
values_b = values_k[inds_b]
|
||||
@ -313,7 +339,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
)
|
||||
|
||||
# Plot state trajectories
|
||||
# Plot state trajectories (only in D3IL)
|
||||
if (
|
||||
self.itr % self.render_freq == 0
|
||||
and self.n_render > 0
|
||||
|
@ -94,8 +94,8 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
key: torch.from_numpy(prev_obs_venv[key])
|
||||
.float()
|
||||
.to(self.device)
|
||||
for key in self.obs_dims.keys()
|
||||
} # batch each type of obs and put into dict
|
||||
for key in self.obs_dims
|
||||
}
|
||||
samples = self.model(
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
@ -107,7 +107,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
for k in obs_trajs.keys():
|
||||
for k in obs_trajs:
|
||||
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
@ -152,7 +152,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
with torch.no_grad():
|
||||
# apply image randomization
|
||||
@ -180,7 +180,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||
)
|
||||
obs_ts = [{} for _ in range(num_split)]
|
||||
for k in obs_trajs.keys():
|
||||
for k in obs_trajs:
|
||||
obs_k = einops.rearrange(
|
||||
obs_trajs[k],
|
||||
"s e ... -> (s e) ...",
|
||||
@ -226,7 +226,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||
obs_venv_ts = {
|
||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||
for key in self.obs_dims.keys()
|
||||
for key in self.obs_dims
|
||||
}
|
||||
with torch.no_grad():
|
||||
next_value = (
|
||||
@ -266,7 +266,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
obs_trajs[k],
|
||||
"s e ... -> (s e) ...",
|
||||
)
|
||||
for k in obs_trajs.keys()
|
||||
for k in obs_trajs
|
||||
}
|
||||
samples_k = einops.rearrange(
|
||||
torch.tensor(samples_trajs).float().to(self.device),
|
||||
@ -297,7 +297,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
|
||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
|
||||
samples_b = samples_k[inds_b]
|
||||
returns_b = returns_k[inds_b]
|
||||
values_b = values_k[inds_b]
|
||||
@ -383,6 +383,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -403,7 +404,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -437,7 +438,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||
run_results[-1]["explained_variance"] = explained_var
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
QSM (Q-Score Matching) for diffusion policy.
|
||||
|
||||
Does not support pixel input right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -75,8 +77,8 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
|
||||
# make a FIFO replay buffer for obs, action, and reward
|
||||
obs_buffer = deque(maxlen=self.buffer_size)
|
||||
action_buffer = deque(maxlen=self.buffer_size)
|
||||
next_obs_buffer = deque(maxlen=self.buffer_size)
|
||||
action_buffer = deque(maxlen=self.buffer_size)
|
||||
reward_buffer = deque(maxlen=self.buffer_size)
|
||||
done_buffer = deque(maxlen=self.buffer_size)
|
||||
first_buffer = deque(maxlen=self.buffer_size)
|
||||
@ -84,6 +86,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -98,9 +101,10 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -108,7 +112,6 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
@ -118,11 +121,14 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
.cpu()
|
||||
@ -137,9 +143,9 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
|
||||
# add to buffer
|
||||
obs_buffer.append(prev_obs_venv)
|
||||
obs_buffer.append(prev_obs_venv["state"])
|
||||
next_obs_buffer.append(obs_venv["state"])
|
||||
action_buffer.append(action_venv)
|
||||
next_obs_buffer.append(obs_venv)
|
||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||
done_buffer.append(done_venv)
|
||||
first_buffer.append(firsts_trajs[step])
|
||||
@ -184,7 +190,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||
@ -233,7 +239,12 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
|
||||
# update critic q function
|
||||
critic_loss = self.model.loss_critic(
|
||||
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
|
||||
{"state": obs_b},
|
||||
{"state": next_obs_b},
|
||||
actions_b,
|
||||
reward_b,
|
||||
done_b,
|
||||
self.gamma,
|
||||
)
|
||||
self.critic_optimizer.zero_grad()
|
||||
critic_loss.backward()
|
||||
@ -246,7 +257,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
|
||||
# Update policy with collected trajectories
|
||||
loss = self.model.loss_actor(
|
||||
obs_b,
|
||||
{"state": obs_b},
|
||||
actions_b,
|
||||
self.q_grad_coeff,
|
||||
)
|
||||
@ -274,6 +285,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -294,7 +306,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -310,7 +322,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["loss_critic"] = loss_critic
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""
|
||||
Reward-weighted regression (RWR) for diffusion policy.
|
||||
|
||||
Does not support pixel input right now.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -54,6 +56,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
# Start training loop
|
||||
timer = Timer()
|
||||
run_results = []
|
||||
last_itr_eval = False
|
||||
done_venv = np.zeros((1, self.n_envs))
|
||||
while self.itr < self.n_train_itr:
|
||||
|
||||
@ -68,9 +71,10 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
# Define train or eval - all envs restart
|
||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||
self.model.eval() if eval_mode else self.model.train()
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
last_itr_eval = eval_mode
|
||||
|
||||
# Reset env at the beginning of an iteration
|
||||
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||
firsts_trajs[0] = 1
|
||||
@ -78,11 +82,11 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
firsts_trajs[0] = (
|
||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||
)
|
||||
last_itr_eval = eval_mode
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Holders
|
||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
# Holder
|
||||
obs_trajs = {
|
||||
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||
}
|
||||
samples_trajs = np.empty(
|
||||
(
|
||||
0,
|
||||
@ -91,6 +95,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
self.action_dim,
|
||||
)
|
||||
)
|
||||
reward_trajs = np.empty((0, self.n_envs))
|
||||
|
||||
# Collect a set of trajectories from env
|
||||
for step in range(self.n_steps):
|
||||
@ -99,24 +104,29 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
|
||||
# Select action
|
||||
with torch.no_grad():
|
||||
cond = {
|
||||
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||
.float()
|
||||
.to(self.device)
|
||||
}
|
||||
samples = (
|
||||
self.model(
|
||||
cond=torch.from_numpy(prev_obs_venv)
|
||||
.float()
|
||||
.to(self.device),
|
||||
cond=cond,
|
||||
deterministic=eval_mode,
|
||||
)
|
||||
.cpu()
|
||||
.numpy()
|
||||
) # n_env x horizon x act
|
||||
action_venv = samples[:, : self.act_steps]
|
||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
||||
samples_trajs = np.vstack((samples_trajs, samples[None]))
|
||||
|
||||
# Apply multi-step action
|
||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||
action_venv
|
||||
)
|
||||
obs_trajs["state"] = np.vstack(
|
||||
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||
)
|
||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||
firsts_trajs[step + 1] = done_venv
|
||||
prev_obs_venv = obs_venv
|
||||
@ -133,7 +143,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
if len(episodes_start_end) > 0:
|
||||
# Compute transitions for completed trajectories
|
||||
obs_trajs_split = [
|
||||
obs_trajs[start : end + 1, env_ind]
|
||||
{"state": obs_trajs["state"][start : end + 1, env_ind]}
|
||||
for env_ind, start, end in episodes_start_end
|
||||
]
|
||||
samples_trajs_split = [
|
||||
@ -183,17 +193,20 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
success_rate = 0
|
||||
log.info("[WARNING] No episode completed within the iteration!")
|
||||
|
||||
# Update
|
||||
# Update models
|
||||
if not eval_mode:
|
||||
|
||||
# Tensorize data and put them to device
|
||||
# k for environment step
|
||||
obs_k = (
|
||||
torch.tensor(np.concatenate(obs_trajs_split))
|
||||
obs_k = {
|
||||
"state": torch.tensor(
|
||||
np.concatenate(
|
||||
[obs_traj["state"] for obs_traj in obs_trajs_split]
|
||||
)
|
||||
)
|
||||
.float()
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
}
|
||||
samples_k = (
|
||||
torch.tensor(np.concatenate(samples_trajs_split))
|
||||
.float()
|
||||
@ -204,19 +217,15 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
returns_trajs_split = (
|
||||
returns_trajs_split - np.mean(returns_trajs_split)
|
||||
) / (returns_trajs_split.std() + 1e-3)
|
||||
|
||||
rewards_k = (
|
||||
torch.tensor(returns_trajs_split)
|
||||
.float()
|
||||
.to(self.device)
|
||||
.reshape(-1)
|
||||
)
|
||||
|
||||
rewards_k_scaled = torch.exp(self.beta * rewards_k)
|
||||
rewards_k_scaled.clamp_(max=self.max_reward_weight)
|
||||
|
||||
# rewards_k_scaled = rewards_k_scaled / rewards_k_scaled.mean()
|
||||
|
||||
# Update policy and critic
|
||||
total_steps = len(rewards_k_scaled)
|
||||
inds_k = np.arange(total_steps)
|
||||
@ -229,7 +238,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
start = batch * self.batch_size
|
||||
end = start + self.batch_size
|
||||
inds_b = inds_k[start:end] # b for batch
|
||||
obs_b = obs_k[inds_b]
|
||||
obs_b = {"state": obs_k["state"][inds_b]}
|
||||
samples_b = samples_k[inds_b]
|
||||
rewards_b = rewards_k_scaled[inds_b]
|
||||
|
||||
@ -261,6 +270,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
}
|
||||
)
|
||||
if self.itr % self.log_freq == 0:
|
||||
time = timer()
|
||||
if eval_mode:
|
||||
log.info(
|
||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||
@ -281,7 +291,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||
else:
|
||||
log.info(
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||
)
|
||||
if self.use_wandb:
|
||||
wandb.log(
|
||||
@ -295,7 +305,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
||||
)
|
||||
run_results[-1]["loss"] = loss
|
||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||
run_results[-1]["time"] = timer()
|
||||
run_results[-1]["time"] = time
|
||||
with open(self.result_path, "wb") as f:
|
||||
pickle.dump(run_results, f)
|
||||
self.itr += 1
|
||||
|
@ -105,15 +105,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -105,15 +105,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -105,15 +105,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
||||
_target_: model.common.critic.CriticObs
|
||||
mlp_dims: [256, 256, 256]
|
||||
residual_style: True
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -54,9 +54,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -47,7 +47,7 @@ model:
|
||||
residual_style: False
|
||||
fixed_std: 0.1
|
||||
num_modes: ${num_modes}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
|
@ -54,9 +54,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -47,7 +47,7 @@ model:
|
||||
residual_style: False
|
||||
fixed_std: 0.1
|
||||
num_modes: ${num_modes}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
|
@ -54,9 +54,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -47,7 +47,7 @@ model:
|
||||
residual_style: False
|
||||
fixed_std: 0.1
|
||||
num_modes: ${num_modes}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -100,10 +100,9 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -58,9 +58,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -57,9 +57,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -58,9 +58,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -57,9 +57,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -57,9 +57,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -56,9 +56,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -57,9 +57,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -83,20 +83,19 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -90,13 +90,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -93,19 +93,18 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
critic_v:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -99,20 +99,18 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -82,6 +82,5 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -83,20 +83,19 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -90,13 +90,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -93,19 +93,18 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
critic_v:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -98,20 +98,18 @@ model:
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -82,6 +82,5 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -83,20 +83,19 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -90,13 +90,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -93,19 +93,18 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
critic_v:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -91,13 +91,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -82,6 +82,5 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
residual_style: True
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -45,7 +45,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -55,9 +55,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -45,7 +45,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -55,9 +55,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -45,7 +45,7 @@ model:
|
||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
cond_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
time_dim: 16
|
||||
mlp_dims: [512, 512, 512]
|
||||
activation_type: ReLU
|
||||
@ -55,9 +55,7 @@ model:
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
||||
|
||||
ema:
|
||||
|
@ -94,13 +94,12 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -95,13 +95,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -94,13 +94,12 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -97,19 +97,18 @@ model:
|
||||
_target_: model.common.critic.CriticObsAct
|
||||
action_dim: ${action_dim}
|
||||
action_steps: ${act_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
critic_v:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
transition_dim: ${transition_dim}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -100,15 +100,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -20,6 +20,7 @@ transition_dim: ${action_dim}
|
||||
denoising_steps: 100
|
||||
ft_denoising_steps: 5
|
||||
cond_steps: 1
|
||||
img_cond_steps: 1
|
||||
horizon_steps: 4
|
||||
act_steps: 4
|
||||
use_ddim: True
|
||||
@ -121,6 +122,7 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -133,6 +135,7 @@ model:
|
||||
time_dim: 32
|
||||
mlp_dims: [512, 512, 512]
|
||||
residual_style: True
|
||||
img_cond_steps: ${img_cond_steps}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
@ -143,6 +146,7 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -150,15 +154,14 @@ model:
|
||||
num_heads: 4
|
||||
embed_style: embed2
|
||||
embed_norm: 0
|
||||
obs_dim: ${obs_dim}
|
||||
img_cond_steps: ${img_cond_steps}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -103,15 +103,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
ft_denoising_steps: ${ft_denoising_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
horizon_steps: ${horizon_steps}
|
||||
obs_dim: ${obs_dim}
|
||||
action_dim: ${action_dim}
|
||||
cond_steps: ${cond_steps}
|
||||
denoising_steps: ${denoising_steps}
|
||||
device: ${device}
|
@ -94,10 +94,9 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -18,6 +18,7 @@ obs_dim: 9
|
||||
action_dim: 7
|
||||
transition_dim: ${action_dim}
|
||||
cond_steps: 1
|
||||
img_cond_steps: 1
|
||||
horizon_steps: 4
|
||||
act_steps: 4
|
||||
|
||||
@ -100,6 +101,7 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -115,6 +117,7 @@ model:
|
||||
learn_fixed_std: True
|
||||
std_min: 0.01
|
||||
std_max: 0.2
|
||||
img_cond_steps: ${img_cond_steps}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
transition_dim: ${transition_dim}
|
||||
@ -125,6 +128,7 @@ model:
|
||||
backbone:
|
||||
_target_: model.common.vit.VitEncoder
|
||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||
cfg:
|
||||
patch_size: 8
|
||||
depth: 1
|
||||
@ -132,10 +136,10 @@ model:
|
||||
num_heads: 4
|
||||
embed_style: embed2
|
||||
embed_norm: 0
|
||||
obs_dim: ${obs_dim}
|
||||
img_cond_steps: ${img_cond_steps}
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
@ -26,7 +26,7 @@ env:
|
||||
name: ${env_name}
|
||||
best_reward_threshold_for_success: 1
|
||||
max_episode_steps: 300
|
||||
save_video: false
|
||||
save_video: False
|
||||
wrappers:
|
||||
robomimic_lowdim:
|
||||
normalization_path: ${normalization_path}
|
||||
@ -95,10 +95,9 @@ model:
|
||||
transition_dim: ${transition_dim}
|
||||
critic:
|
||||
_target_: model.common.critic.CriticObs
|
||||
obs_dim: ${obs_dim}
|
||||
mlp_dims: [256, 256, 256]
|
||||
activation_type: Mish
|
||||
residual_style: True
|
||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||
horizon_steps: ${horizon_steps}
|
||||
cond_steps: ${cond_steps}
|
||||
device: ${device}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user