squash commits
This commit is contained in:
parent
8ce0aa1485
commit
2ddf63b8f5
@ -16,19 +16,21 @@ import random
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
Batch = namedtuple("Batch", "trajectories conditions")
|
Batch = namedtuple("Batch", "actions conditions")
|
||||||
|
|
||||||
|
|
||||||
class StitchedSequenceDataset(torch.utils.data.Dataset):
|
class StitchedSequenceDataset(torch.utils.data.Dataset):
|
||||||
"""
|
"""
|
||||||
Dataset to efficiently load and sample trajectories. Stitches episodes together in the time dimension to avoid excessive zero padding. Episode ID's are used to index unique trajectories.
|
Load stitched trajectories of states/actions/images, and 1-D array of traj_lengths, from npz or pkl file.
|
||||||
|
|
||||||
Returns a dictionary with values of shape: [sum_e(T_e), *D] where T_e is traj length of episode e and D is
|
Use the first max_n_episodes episodes (instead of random sampling)
|
||||||
(tuple of) dimension of observation, action, images, etc.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
|
states: [----------traj 1----------][---------traj 2----------] ... [---------traj N----------]
|
||||||
Episode IDs: [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
|
Episode IDs (determined based on traj_lengths): [---------- 1 ----------][---------- 2 ---------] ... [---------- N ---------]
|
||||||
|
|
||||||
|
Each sample is a namedtuple of (1) chunked actions and (2) a list (obs timesteps) of dictionary with keys states and images.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -36,23 +38,30 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
|||||||
dataset_path,
|
dataset_path,
|
||||||
horizon_steps=64,
|
horizon_steps=64,
|
||||||
cond_steps=1,
|
cond_steps=1,
|
||||||
|
img_cond_steps=1,
|
||||||
max_n_episodes=10000,
|
max_n_episodes=10000,
|
||||||
use_img=False,
|
use_img=False,
|
||||||
device="cuda:0",
|
device="cuda:0",
|
||||||
):
|
):
|
||||||
|
assert (
|
||||||
|
img_cond_steps <= cond_steps
|
||||||
|
), "consider using more cond_steps than img_cond_steps"
|
||||||
self.horizon_steps = horizon_steps
|
self.horizon_steps = horizon_steps
|
||||||
self.cond_steps = cond_steps
|
self.cond_steps = cond_steps # states (proprio, etc.)
|
||||||
|
self.img_cond_steps = img_cond_steps
|
||||||
self.device = device
|
self.device = device
|
||||||
self.use_img = use_img
|
self.use_img = use_img
|
||||||
|
|
||||||
# Load dataset to device specified
|
# Load dataset to device specified
|
||||||
if dataset_path.endswith(".npz"):
|
if dataset_path.endswith(".npz"):
|
||||||
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
|
dataset = np.load(dataset_path, allow_pickle=False) # only np arrays
|
||||||
else:
|
elif dataset_path.endswith(".pkl"):
|
||||||
with open(dataset_path, "rb") as f:
|
with open(dataset_path, "rb") as f:
|
||||||
dataset = pickle.load(f)
|
dataset = pickle.load(f)
|
||||||
traj_lengths = dataset["traj_lengths"] # 1-D array
|
else:
|
||||||
total_num_steps = np.sum(traj_lengths[:max_n_episodes])
|
raise ValueError(f"Unsupported file format: {dataset_path}")
|
||||||
|
traj_lengths = dataset["traj_lengths"][:max_n_episodes] # 1-D array
|
||||||
|
total_num_steps = np.sum(traj_lengths)
|
||||||
|
|
||||||
# Set up indices for sampling
|
# Set up indices for sampling
|
||||||
self.indices = self.make_indices(traj_lengths, horizon_steps)
|
self.indices = self.make_indices(traj_lengths, horizon_steps)
|
||||||
@ -75,35 +84,51 @@ class StitchedSequenceDataset(torch.utils.data.Dataset):
|
|||||||
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
|
log.info(f"Images shape/type: {self.images.shape, self.images.dtype}")
|
||||||
|
|
||||||
def __getitem__(self, idx):
    """
    Return the sample stored at position ``idx`` of ``self.indices``.

    Each entry of ``self.indices`` is a pair ``(start, num_before_start)``:
    ``start`` is the position of the sample in the stitched arrays and
    ``num_before_start`` is how many steps of the same episode precede it
    (see ``make_indices``).

    The action chunk covers ``[start, start + horizon_steps)``. The
    observation history holds the ``cond_steps`` most recent states (and,
    when ``use_img`` is set, the ``img_cond_steps`` most recent images)
    ending at ``start``, oldest first. When fewer steps than requested exist
    before ``start`` in the episode, the first frame of the episode is
    repeated so the history never reaches into the previous episode.

    Returns:
        Batch(actions, conditions), where ``conditions`` is a dict with key
        "state" and, if ``use_img``, key "rgb".
    """
    start, num_before_start = self.indices[idx]
    end = start + self.horizon_steps
    # Slice from the first step of the episode, so that local offset
    # ``num_before_start`` corresponds to global step ``start``.
    episode_start = start - num_before_start

    def stack_history(frames, n_steps):
        # Offset ``num_before_start - t`` is the frame t steps before ``start``.
        # Clamp with max(..., 0): a negative offset would wrap around to the END
        # of the slice (future frames) instead of repeating the first frame.
        # NOTE(review): torch.stack requires the indexed frames to be tensors;
        # presumably self.states / self.images are torch tensors - confirm in __init__.
        return torch.stack(
            [
                frames[max(num_before_start - t, 0)]
                for t in reversed(range(n_steps))
            ]
        )  # more recent is at the end

    actions = self.actions[start:end]
    states = stack_history(self.states[episode_start:end], self.cond_steps)
    conditions = {"state": states}
    if self.use_img:
        images = stack_history(
            self.images[episode_start:end], self.img_cond_steps
        )
        conditions["rgb"] = images
    batch = Batch(actions, conditions)
    return batch
|
||||||
|
|
||||||
def make_indices(self, traj_lengths, horizon_steps):
    """
    Enumerate every valid sample position in the stitched dataset.

    Episodes are laid out back to back, so episode k starts at the sum of
    the lengths of the episodes before it. A position is valid when a full
    window of ``horizon_steps`` steps fits inside its own episode; episodes
    shorter than ``horizon_steps`` contribute nothing.

    Returns:
        A list of ``(position, steps_before)`` tuples, where ``position`` is
        the index into the stitched arrays and ``steps_before`` is the number
        of steps that precede it within the same episode.
    """
    samples = []
    episode_start = 0
    for episode_len in traj_lengths:
        # Last start position that still leaves room for a full horizon.
        last_start = episode_start + episode_len - horizon_steps
        for position in range(episode_start, last_start + 1):
            samples.append((position, position - episode_start))
        episode_start += episode_len
    return samples
|
||||||
|
|
||||||
def set_train_val_split(self, train_split):
|
def set_train_val_split(self, train_split):
|
||||||
"""Not doing validation right now"""
|
"""
|
||||||
|
Not doing validation right now
|
||||||
|
"""
|
||||||
num_train = int(len(self.indices) * train_split)
|
num_train = int(len(self.indices) * train_split)
|
||||||
train_indices = random.sample(self.indices, num_train)
|
train_indices = random.sample(self.indices, num_train)
|
||||||
val_indices = [i for i in range(len(self.indices)) if i not in train_indices]
|
val_indices = [i for i in range(len(self.indices)) if i not in train_indices]
|
||||||
|
@ -3,6 +3,8 @@ Advantage-weighted regression (AWR) for diffusion policy.
|
|||||||
|
|
||||||
Advantage = discounted-reward-to-go - V(s)
|
Advantage = discounted-reward-to-go - V(s)
|
||||||
|
|
||||||
|
Does not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -131,6 +133,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -145,9 +148,10 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -155,7 +159,6 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
@ -165,16 +168,19 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy()
|
.numpy()
|
||||||
) # n_env x horizon x act
|
)
|
||||||
action_venv = samples[:, : self.act_steps]
|
action_venv = samples[:, : self.act_steps]
|
||||||
|
|
||||||
# Apply multi-step action
|
# Apply multi-step action
|
||||||
@ -184,7 +190,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
|
|
||||||
# add to buffer
|
# add to buffer
|
||||||
obs_buffer.append(prev_obs_venv)
|
obs_buffer.append(prev_obs_venv["state"])
|
||||||
action_buffer.append(action_venv)
|
action_buffer.append(action_venv)
|
||||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||||
done_buffer.append(done_venv)
|
done_buffer.append(done_venv)
|
||||||
@ -230,59 +236,46 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
obs_trajs = np.array(deepcopy(obs_buffer)) # assume only state
|
||||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
|
||||||
reward_trajs = np.array(deepcopy(reward_buffer))
|
reward_trajs = np.array(deepcopy(reward_buffer))
|
||||||
dones_trajs = np.array(deepcopy(done_buffer))
|
dones_trajs = np.array(deepcopy(done_buffer))
|
||||||
|
|
||||||
obs_t = einops.rearrange(
|
obs_t = einops.rearrange(
|
||||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||||
"s e h d -> (s e) h d",
|
"s e h d -> (s e) h d",
|
||||||
)
|
)
|
||||||
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
|
values_trajs = np.array(
|
||||||
values_trajs = values_t.reshape(-1, self.n_envs)
|
self.model.critic({"state": obs_t}).detach().cpu().numpy()
|
||||||
|
).reshape(-1, self.n_envs)
|
||||||
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
||||||
|
td_t = torch.from_numpy(td_trajs.flatten()).float().to(self.device)
|
||||||
|
|
||||||
# flatten
|
# Update critic
|
||||||
obs_trajs = einops.rearrange(
|
|
||||||
obs_trajs,
|
|
||||||
"s e h d -> (s e) h d",
|
|
||||||
)
|
|
||||||
td_trajs = einops.rearrange(
|
|
||||||
td_trajs,
|
|
||||||
"s e -> (s e)",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update policy and critic
|
|
||||||
num_batch = int(
|
num_batch = int(
|
||||||
self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
|
self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
|
||||||
)
|
)
|
||||||
for _ in range(num_batch // self.critic_update_ratio):
|
for _ in range(num_batch // self.critic_update_ratio):
|
||||||
|
|
||||||
# Sample batch
|
|
||||||
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
||||||
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
|
loss_critic = self.model.loss_critic(
|
||||||
td_b = torch.from_numpy(td_trajs[inds]).float().to(self.device)
|
{"state": obs_t[inds]}, td_t[inds]
|
||||||
|
)
|
||||||
# Update critic
|
|
||||||
loss_critic = self.model.loss_critic(obs_b, td_b)
|
|
||||||
self.critic_optimizer.zero_grad()
|
self.critic_optimizer.zero_grad()
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
self.critic_optimizer.step()
|
self.critic_optimizer.step()
|
||||||
|
|
||||||
|
# Update policy - use a new copy of data
|
||||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||||
samples_trajs = np.array(deepcopy(action_buffer))
|
samples_trajs = np.array(deepcopy(action_buffer))
|
||||||
reward_trajs = np.array(deepcopy(reward_buffer))
|
reward_trajs = np.array(deepcopy(reward_buffer))
|
||||||
dones_trajs = np.array(deepcopy(done_buffer))
|
dones_trajs = np.array(deepcopy(done_buffer))
|
||||||
|
|
||||||
obs_t = einops.rearrange(
|
obs_t = einops.rearrange(
|
||||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
torch.from_numpy(obs_trajs).float().to(self.device),
|
||||||
"s e h d -> (s e) h d",
|
"s e h d -> (s e) h d",
|
||||||
)
|
)
|
||||||
values_t = np.array(self.model.critic(obs_t).detach().cpu().numpy())
|
values_trajs = np.array(
|
||||||
values_trajs = values_t.reshape(-1, self.n_envs)
|
self.model.critic({"state": obs_t}).detach().cpu().numpy()
|
||||||
|
).reshape(-1, self.n_envs)
|
||||||
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
td_trajs = td_values(obs_trajs, reward_trajs, dones_trajs, values_trajs)
|
||||||
advantages_trajs = td_trajs - values_trajs
|
advantages_trajs = td_trajs - values_trajs
|
||||||
|
|
||||||
@ -304,7 +297,11 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Sample batch
|
# Sample batch
|
||||||
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
inds = np.random.choice(len(obs_trajs), self.batch_size)
|
||||||
obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
|
obs_b = {
|
||||||
|
"state": torch.from_numpy(obs_trajs[inds])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
actions_b = (
|
actions_b = (
|
||||||
torch.from_numpy(samples_trajs[inds]).float().to(self.device)
|
torch.from_numpy(samples_trajs[inds]).float().to(self.device)
|
||||||
)
|
)
|
||||||
@ -347,6 +344,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -367,7 +365,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -383,7 +381,7 @@ class TrainAWRDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["loss_critic"] = loss_critic
|
run_results[-1]["loss_critic"] = loss_critic
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -4,6 +4,9 @@ Model-free online RL with DIffusion POlicy (DIPO)
|
|||||||
Applies action gradient to perturb actions towards maximizer of Q-function.
|
Applies action gradient to perturb actions towards maximizer of Q-function.
|
||||||
|
|
||||||
a_t <- a_t + \eta * \grad_a Q(s, a)
|
a_t <- a_t + \eta * \grad_a Q(s, a)
|
||||||
|
|
||||||
|
Does not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -90,6 +93,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -104,9 +108,10 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -114,7 +119,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
@ -124,11 +128,14 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
.cpu()
|
.cpu()
|
||||||
@ -144,8 +151,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# add to buffer
|
# add to buffer
|
||||||
for i in range(self.n_envs):
|
for i in range(self.n_envs):
|
||||||
obs_buffer.append(prev_obs_venv[i])
|
obs_buffer.append(prev_obs_venv["state"][i])
|
||||||
next_obs_buffer.append(obs_venv[i])
|
next_obs_buffer.append(obs_venv["state"][i])
|
||||||
action_buffer.append(action_venv[i])
|
action_buffer.append(action_venv[i])
|
||||||
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
||||||
done_buffer.append(done_venv[i])
|
done_buffer.append(done_venv[i])
|
||||||
@ -191,8 +198,8 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
|
||||||
num_batch = self.replay_ratio
|
num_batch = self.replay_ratio
|
||||||
|
|
||||||
# Critic learning
|
# Critic learning
|
||||||
@ -231,7 +238,12 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Update critic
|
# Update critic
|
||||||
loss_critic = self.model.loss_critic(
|
loss_critic = self.model.loss_critic(
|
||||||
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
|
{"state": obs_b},
|
||||||
|
{"state": next_obs_b},
|
||||||
|
actions_b,
|
||||||
|
rewards_b,
|
||||||
|
dones_b,
|
||||||
|
self.gamma,
|
||||||
)
|
)
|
||||||
self.critic_optimizer.zero_grad()
|
self.critic_optimizer.zero_grad()
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
@ -239,7 +251,6 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Actor learning
|
# Actor learning
|
||||||
for _ in range(num_batch):
|
for _ in range(num_batch):
|
||||||
|
|
||||||
# Sample batch
|
# Sample batch
|
||||||
inds = np.random.choice(len(obs_buffer), self.batch_size)
|
inds = np.random.choice(len(obs_buffer), self.batch_size)
|
||||||
obs_b = (
|
obs_b = (
|
||||||
@ -265,7 +276,9 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
)
|
)
|
||||||
for _ in range(self.action_gradient_steps):
|
for _ in range(self.action_gradient_steps):
|
||||||
actions_flat.requires_grad_(True)
|
actions_flat.requires_grad_(True)
|
||||||
q_values_1, q_values_2 = self.model.critic(obs_b, actions_flat)
|
q_values_1, q_values_2 = self.model.critic(
|
||||||
|
{"state": obs_b}, actions_flat
|
||||||
|
)
|
||||||
q_values = torch.min(q_values_1, q_values_2)
|
q_values = torch.min(q_values_1, q_values_2)
|
||||||
action_opt_loss = -q_values.sum()
|
action_opt_loss = -q_values.sum()
|
||||||
|
|
||||||
@ -291,7 +304,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update policy with collected trajectories
|
# Update policy with collected trajectories
|
||||||
loss = self.model.loss(guided_action.detach(), {0: obs_b})
|
loss = self.model.loss(guided_action.detach(), {"state": obs_b})
|
||||||
self.actor_optimizer.zero_grad()
|
self.actor_optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
if self.itr >= self.n_critic_warmup_itr:
|
if self.itr >= self.n_critic_warmup_itr:
|
||||||
@ -316,6 +329,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -336,7 +350,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -352,7 +366,7 @@ class TrainDIPODiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["loss_critic"] = loss_critic
|
run_results[-1]["loss_critic"] = loss_critic
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -5,6 +5,9 @@ Learns a critic Q-function and backprops the expected Q-value to train the actor
|
|||||||
|
|
||||||
pi = argmin L_d(\theta) - \alpha * E[Q(s, a)]
|
pi = argmin L_d(\theta) - \alpha * E[Q(s, a)]
|
||||||
L_d is demonstration loss for regularization
|
L_d is demonstration loss for regularization
|
||||||
|
|
||||||
|
Does not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -38,7 +41,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
lr=cfg.train.actor_lr,
|
lr=cfg.train.actor_lr,
|
||||||
weight_decay=cfg.train.actor_weight_decay,
|
weight_decay=cfg.train.actor_weight_decay,
|
||||||
)
|
)
|
||||||
# use cosine scheduler with linear warmup
|
|
||||||
self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
|
self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
|
||||||
self.actor_optimizer,
|
self.actor_optimizer,
|
||||||
first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
|
first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
|
||||||
@ -88,6 +90,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -102,9 +105,10 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -112,7 +116,6 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
@ -122,11 +125,14 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
.cpu()
|
.cpu()
|
||||||
@ -142,8 +148,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# add to buffer
|
# add to buffer
|
||||||
for i in range(self.n_envs):
|
for i in range(self.n_envs):
|
||||||
obs_buffer.append(prev_obs_venv[i])
|
obs_buffer.append(prev_obs_venv["state"][i])
|
||||||
next_obs_buffer.append(obs_venv[i])
|
next_obs_buffer.append(obs_venv["state"][i])
|
||||||
action_buffer.append(action_venv[i])
|
action_buffer.append(action_venv[i])
|
||||||
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
|
||||||
done_buffer.append(done_venv[i])
|
done_buffer.append(done_venv[i])
|
||||||
@ -189,8 +195,8 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
|
||||||
num_batch = self.replay_ratio
|
num_batch = self.replay_ratio
|
||||||
|
|
||||||
# Critic learning
|
# Critic learning
|
||||||
@ -229,7 +235,12 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Update critic
|
# Update critic
|
||||||
loss_critic = self.model.loss_critic(
|
loss_critic = self.model.loss_critic(
|
||||||
obs_b, next_obs_b, actions_b, rewards_b, dones_b, self.gamma
|
{"state": obs_b},
|
||||||
|
{"state": next_obs_b},
|
||||||
|
actions_b,
|
||||||
|
rewards_b,
|
||||||
|
dones_b,
|
||||||
|
self.gamma,
|
||||||
)
|
)
|
||||||
self.critic_optimizer.zero_grad()
|
self.critic_optimizer.zero_grad()
|
||||||
loss_critic.backward()
|
loss_critic.backward()
|
||||||
@ -237,19 +248,21 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# get the new action and q values
|
# get the new action and q values
|
||||||
samples = self.model.forward_train(
|
samples = self.model.forward_train(
|
||||||
cond=obs_b.to(self.device),
|
cond={"state": obs_b},
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
output_venv = samples # n_env x horizon x act
|
action_venv = samples[:, : self.act_steps] # n_env x horizon x act
|
||||||
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
|
q_values_b = self.model.critic({"state": obs_b}, action_venv)
|
||||||
actions_flat_b = action_venv.reshape(action_venv.shape[0], -1)
|
|
||||||
q_values_b = self.model.critic(obs_b, actions_flat_b)
|
|
||||||
q1_new_action, q2_new_action = q_values_b
|
q1_new_action, q2_new_action = q_values_b
|
||||||
|
|
||||||
# Update policy with collected trajectories
|
# Update policy with collected trajectories
|
||||||
self.actor_optimizer.zero_grad()
|
self.actor_optimizer.zero_grad()
|
||||||
actor_loss = self.model.loss_actor(
|
actor_loss = self.model.loss_actor(
|
||||||
obs_b, actions_b, q1_new_action, q2_new_action, self.eta
|
{"state": obs_b},
|
||||||
|
actions_b,
|
||||||
|
q1_new_action,
|
||||||
|
q2_new_action,
|
||||||
|
self.eta,
|
||||||
)
|
)
|
||||||
actor_loss.backward()
|
actor_loss.backward()
|
||||||
if self.itr >= self.n_critic_warmup_itr:
|
if self.itr >= self.n_critic_warmup_itr:
|
||||||
@ -275,6 +288,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -295,7 +309,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -311,7 +325,7 @@ class TrainDQLDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["loss_critic"] = loss_critic
|
run_results[-1]["loss_critic"] = loss_critic
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
|
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
|
||||||
|
|
||||||
|
Do not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -11,7 +13,6 @@ import torch
|
|||||||
import logging
|
import logging
|
||||||
import wandb
|
import wandb
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import random
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
from util.timer import Timer
|
from util.timer import Timer
|
||||||
@ -98,8 +99,8 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# make a FIFO replay buffer for obs, action, and reward
|
# make a FIFO replay buffer for obs, action, and reward
|
||||||
obs_buffer = deque(maxlen=self.buffer_size)
|
obs_buffer = deque(maxlen=self.buffer_size)
|
||||||
action_buffer = deque(maxlen=self.buffer_size)
|
|
||||||
next_obs_buffer = deque(maxlen=self.buffer_size)
|
next_obs_buffer = deque(maxlen=self.buffer_size)
|
||||||
|
action_buffer = deque(maxlen=self.buffer_size)
|
||||||
reward_buffer = deque(maxlen=self.buffer_size)
|
reward_buffer = deque(maxlen=self.buffer_size)
|
||||||
done_buffer = deque(maxlen=self.buffer_size)
|
done_buffer = deque(maxlen=self.buffer_size)
|
||||||
first_buffer = deque(maxlen=self.buffer_size)
|
first_buffer = deque(maxlen=self.buffer_size)
|
||||||
@ -107,6 +108,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -121,9 +123,10 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -131,7 +134,6 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
@ -141,11 +143,14 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode and self.eval_deterministic,
|
deterministic=eval_mode and self.eval_deterministic,
|
||||||
num_sample=self.num_sample,
|
num_sample=self.num_sample,
|
||||||
use_expectile_exploration=self.use_expectile_exploration,
|
use_expectile_exploration=self.use_expectile_exploration,
|
||||||
@ -162,9 +167,9 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
|
|
||||||
# add to buffer
|
# add to buffer
|
||||||
obs_buffer.append(prev_obs_venv)
|
obs_buffer.append(prev_obs_venv["state"])
|
||||||
|
next_obs_buffer.append(obs_venv["state"])
|
||||||
action_buffer.append(action_venv)
|
action_buffer.append(action_venv)
|
||||||
next_obs_buffer.append(obs_venv)
|
|
||||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||||
done_buffer.append(done_venv)
|
done_buffer.append(done_venv)
|
||||||
first_buffer.append(firsts_trajs[step])
|
first_buffer.append(firsts_trajs[step])
|
||||||
@ -209,7 +214,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
|
||||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||||
@ -257,14 +262,21 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device)
|
done_b = torch.from_numpy(done_trajs[inds]).float().to(self.device)
|
||||||
|
|
||||||
# update critic value function
|
# update critic value function
|
||||||
critic_loss_v = self.model.loss_critic_v(obs_b, actions_b)
|
critic_loss_v = self.model.loss_critic_v(
|
||||||
|
{"state": obs_b}, actions_b
|
||||||
|
)
|
||||||
self.critic_v_optimizer.zero_grad()
|
self.critic_v_optimizer.zero_grad()
|
||||||
critic_loss_v.backward()
|
critic_loss_v.backward()
|
||||||
self.critic_v_optimizer.step()
|
self.critic_v_optimizer.step()
|
||||||
|
|
||||||
# update critic q function
|
# update critic q function
|
||||||
critic_loss_q = self.model.loss_critic_q(
|
critic_loss_q = self.model.loss_critic_q(
|
||||||
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
|
{"state": obs_b},
|
||||||
|
{"state": next_obs_b},
|
||||||
|
actions_b,
|
||||||
|
reward_b,
|
||||||
|
done_b,
|
||||||
|
self.gamma,
|
||||||
)
|
)
|
||||||
self.critic_q_optimizer.zero_grad()
|
self.critic_q_optimizer.zero_grad()
|
||||||
critic_loss_q.backward()
|
critic_loss_q.backward()
|
||||||
@ -278,7 +290,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
# Update policy with collected trajectories - no weighting
|
# Update policy with collected trajectories - no weighting
|
||||||
loss = self.model.loss(
|
loss = self.model.loss(
|
||||||
actions_b,
|
actions_b,
|
||||||
obs_b,
|
{"state": obs_b},
|
||||||
)
|
)
|
||||||
self.actor_optimizer.zero_grad()
|
self.actor_optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -305,6 +317,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -325,7 +338,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -341,7 +354,7 @@ class TrainIDQLDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["loss_critic"] = loss_critic
|
run_results[-1]["loss_critic"] = loss_critic
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -10,6 +10,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import wandb
|
import wandb
|
||||||
|
import math
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
from util.timer import Timer
|
from util.timer import Timer
|
||||||
@ -78,7 +79,9 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Holder
|
# Holder
|
||||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
obs_trajs = {
|
||||||
|
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||||
|
}
|
||||||
chains_trajs = np.empty(
|
chains_trajs = np.empty(
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
@ -91,8 +94,8 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||||
) # remove cond_step dim
|
) # save current obs
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
for step in range(self.n_steps):
|
for step in range(self.n_steps):
|
||||||
@ -101,8 +104,13 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = self.model(
|
samples = self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
cond=cond,
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
return_chain=True,
|
return_chain=True,
|
||||||
)
|
)
|
||||||
@ -118,14 +126,16 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
if self.save_full_observations:
|
if self.save_full_observations: # state-only
|
||||||
obs_full_venv = np.vstack(
|
obs_full_venv = np.array(
|
||||||
[info["full_obs"][None] for info in info_venv]
|
[info["full_obs"]["state"] for info in info_venv]
|
||||||
) # n_envs x n_act_steps x obs_dim
|
) # n_envs x act_steps x obs_dim
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||||
)
|
)
|
||||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
obs_trajs["state"] = np.vstack(
|
||||||
|
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||||
|
)
|
||||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||||
@ -177,12 +187,22 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
# Update models
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
obs_trajs["state"] = (
|
||||||
obs_t = einops.rearrange(
|
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
|
||||||
"s e h d -> (s e) h d",
|
|
||||||
)
|
)
|
||||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
|
||||||
|
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||||
|
num_split = math.ceil(
|
||||||
|
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||||
|
)
|
||||||
|
obs_ts = [{} for _ in range(num_split)]
|
||||||
|
obs_k = einops.rearrange(
|
||||||
|
obs_trajs["state"],
|
||||||
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||||
|
for i, obs_t in enumerate(obs_ts_k):
|
||||||
|
obs_ts[i]["state"] = obs_t
|
||||||
values_trajs = np.empty((0, self.n_envs))
|
values_trajs = np.empty((0, self.n_envs))
|
||||||
for obs in obs_ts:
|
for obs in obs_ts:
|
||||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||||
@ -219,7 +239,11 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
reward_trajs = reward_trajs_transpose.T
|
reward_trajs = reward_trajs_transpose.T
|
||||||
|
|
||||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
obs_venv_ts = {
|
||||||
|
"state": torch.from_numpy(obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_value = (
|
next_value = (
|
||||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||||
@ -250,10 +274,12 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
returns_trajs = advantages_trajs + values_trajs
|
returns_trajs = advantages_trajs + values_trajs
|
||||||
|
|
||||||
# k for environment step
|
# k for environment step
|
||||||
obs_k = einops.rearrange(
|
obs_k = {
|
||||||
torch.tensor(obs_trajs).float().to(self.device),
|
"state": einops.rearrange(
|
||||||
"s e h d -> (s e) h d",
|
obs_trajs["state"],
|
||||||
)
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
}
|
||||||
chains_k = einops.rearrange(
|
chains_k = einops.rearrange(
|
||||||
torch.tensor(chains_trajs).float().to(self.device),
|
torch.tensor(chains_trajs).float().to(self.device),
|
||||||
"s e t h d -> (s e) t h d",
|
"s e t h d -> (s e) t h d",
|
||||||
@ -283,7 +309,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = obs_k[inds_b]
|
obs_b = {"state": obs_k["state"][inds_b]}
|
||||||
chains_b = chains_k[inds_b]
|
chains_b = chains_k[inds_b]
|
||||||
returns_b = returns_k[inds_b]
|
returns_b = returns_k[inds_b]
|
||||||
values_b = values_k[inds_b]
|
values_b = values_k[inds_b]
|
||||||
@ -351,7 +377,7 @@ class TrainPPODiffusionAgent(TrainPPOAgent):
|
|||||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot state trajectories in D3IL
|
# Plot state trajectories (only in D3IL)
|
||||||
if (
|
if (
|
||||||
self.itr % self.render_freq == 0
|
self.itr % self.render_freq == 0
|
||||||
and self.n_render > 0
|
and self.n_render > 0
|
||||||
|
@ -30,7 +30,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
|
|
||||||
# Set obs dim - we will save the different obs in batch in a dict
|
# Set obs dim - we will save the different obs in batch in a dict
|
||||||
shape_meta = cfg.shape_meta
|
shape_meta = cfg.shape_meta
|
||||||
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()}
|
self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs}
|
||||||
|
|
||||||
# Gradient accumulation to deal with large GPU RAM usage
|
# Gradient accumulation to deal with large GPU RAM usage
|
||||||
self.grad_accumulate = cfg.train.grad_accumulate
|
self.grad_accumulate = cfg.train.grad_accumulate
|
||||||
@ -95,7 +95,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
key: torch.from_numpy(prev_obs_venv[key])
|
key: torch.from_numpy(prev_obs_venv[key])
|
||||||
.float()
|
.float()
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
for key in self.obs_dims.keys()
|
for key in self.obs_dims
|
||||||
} # batch each type of obs and put into dict
|
} # batch each type of obs and put into dict
|
||||||
samples = self.model(
|
samples = self.model(
|
||||||
cond=cond,
|
cond=cond,
|
||||||
@ -114,7 +114,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
for k in obs_trajs.keys():
|
for k in obs_trajs:
|
||||||
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
||||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
@ -159,7 +159,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# apply image randomization
|
# apply image randomization
|
||||||
@ -187,7 +187,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||||
)
|
)
|
||||||
obs_ts = [{} for _ in range(num_split)]
|
obs_ts = [{} for _ in range(num_split)]
|
||||||
for k in obs_trajs.keys():
|
for k in obs_trajs:
|
||||||
obs_k = einops.rearrange(
|
obs_k = einops.rearrange(
|
||||||
obs_trajs[k],
|
obs_trajs[k],
|
||||||
"s e ... -> (s e) ...",
|
"s e ... -> (s e) ...",
|
||||||
@ -238,7 +238,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||||
obs_venv_ts = {
|
obs_venv_ts = {
|
||||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||||
for key in self.obs_dims.keys()
|
for key in self.obs_dims
|
||||||
}
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_value = (
|
next_value = (
|
||||||
@ -278,7 +278,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
obs_trajs[k],
|
obs_trajs[k],
|
||||||
"s e ... -> (s e) ...",
|
"s e ... -> (s e) ...",
|
||||||
)
|
)
|
||||||
for k in obs_trajs.keys()
|
for k in obs_trajs
|
||||||
}
|
}
|
||||||
chains_k = einops.rearrange(
|
chains_k = einops.rearrange(
|
||||||
torch.tensor(chains_trajs).float().to(self.device),
|
torch.tensor(chains_trajs).float().to(self.device),
|
||||||
@ -309,7 +309,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
|
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
|
||||||
chains_b = chains_k[inds_b]
|
chains_b = chains_k[inds_b]
|
||||||
returns_b = returns_k[inds_b]
|
returns_b = returns_k[inds_b]
|
||||||
values_b = values_k[inds_b]
|
values_b = values_k[inds_b]
|
||||||
@ -387,7 +387,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update lr
|
# Update lr, min_sampling_std
|
||||||
if self.itr >= self.n_critic_warmup_itr:
|
if self.itr >= self.n_critic_warmup_itr:
|
||||||
self.actor_lr_scheduler.step()
|
self.actor_lr_scheduler.step()
|
||||||
if self.learn_eta:
|
if self.learn_eta:
|
||||||
@ -407,6 +407,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -427,7 +428,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | eta {eta:8.4f} | t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -462,7 +463,7 @@ class TrainPPOImgDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||||
run_results[-1]["explained_variance"] = explained_var
|
run_results[-1]["explained_variance"] = explained_var
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Use diffusion exact likelihood for policy gradient.
|
Use diffusion exact likelihood for policy gradient.
|
||||||
|
|
||||||
|
Do not support pixel input yet.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -10,6 +12,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import wandb
|
import wandb
|
||||||
|
import math
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
from util.timer import Timer
|
from util.timer import Timer
|
||||||
@ -43,6 +46,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
|
|
||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
|
eval_mode = False
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
last_itr_eval = eval_mode
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
@ -58,7 +62,9 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Holder
|
# Holder
|
||||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
obs_trajs = {
|
||||||
|
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||||
|
}
|
||||||
samples_trajs = np.empty(
|
samples_trajs = np.empty(
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
@ -79,8 +85,8 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||||
) # remove cond_step dim
|
) # save current obs
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
for step in range(self.n_steps):
|
for step in range(self.n_steps):
|
||||||
@ -89,8 +95,13 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = self.model(
|
samples = self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
cond=cond,
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
return_chain=True,
|
return_chain=True,
|
||||||
)
|
)
|
||||||
@ -101,21 +112,23 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
samples.chains.cpu().numpy()
|
samples.chains.cpu().numpy()
|
||||||
) # n_env x denoising x horizon x act
|
) # n_env x denoising x horizon x act
|
||||||
action_venv = output_venv[:, : self.act_steps]
|
action_venv = output_venv[:, : self.act_steps]
|
||||||
|
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||||
|
|
||||||
# Apply multi-step action
|
# Apply multi-step action
|
||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
if self.save_full_observations:
|
if self.save_full_observations: # state-only
|
||||||
obs_full_venv = np.vstack(
|
obs_full_venv = np.array(
|
||||||
[info["full_obs"][None] for info in info_venv]
|
[info["full_obs"]["state"] for info in info_venv]
|
||||||
) # n_envs x n_act_steps x obs_dim
|
) # n_envs x act_steps x obs_dim
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||||
)
|
)
|
||||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
obs_trajs["state"] = np.vstack(
|
||||||
|
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||||
|
)
|
||||||
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
chains_trajs = np.vstack((chains_trajs, chains_venv[None]))
|
||||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||||
firsts_trajs[step + 1] = done_venv
|
firsts_trajs[step + 1] = done_venv
|
||||||
@ -158,15 +171,25 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
obs_trajs["state"] = (
|
||||||
obs_t = einops.rearrange(
|
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
|
||||||
"s e h d -> (s e) h d",
|
|
||||||
)
|
)
|
||||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
|
||||||
|
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||||
|
num_split = math.ceil(
|
||||||
|
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||||
|
)
|
||||||
|
obs_ts = [{} for _ in range(num_split)]
|
||||||
|
obs_k = einops.rearrange(
|
||||||
|
obs_trajs["state"],
|
||||||
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||||
|
for i, obs_t in enumerate(obs_ts_k):
|
||||||
|
obs_ts[i]["state"] = obs_t
|
||||||
values_trajs = np.empty((0, self.n_envs))
|
values_trajs = np.empty((0, self.n_envs))
|
||||||
for obs in obs_ts:
|
for obs in obs_ts:
|
||||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||||
@ -193,7 +216,11 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
reward_trajs = reward_trajs_transpose.T
|
reward_trajs = reward_trajs_transpose.T
|
||||||
|
|
||||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
obs_venv_ts = {
|
||||||
|
"state": torch.from_numpy(obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_value = (
|
next_value = (
|
||||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||||
@ -224,10 +251,12 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
returns_trajs = advantages_trajs + values_trajs
|
returns_trajs = advantages_trajs + values_trajs
|
||||||
|
|
||||||
# k for environment step
|
# k for environment step
|
||||||
obs_k = einops.rearrange(
|
obs_k = {
|
||||||
torch.tensor(obs_trajs).float().to(self.device),
|
"state": einops.rearrange(
|
||||||
"s e h d -> (s e) h d",
|
obs_trajs["state"],
|
||||||
)
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
}
|
||||||
samples_k = einops.rearrange(
|
samples_k = einops.rearrange(
|
||||||
torch.tensor(samples_trajs).float().to(self.device),
|
torch.tensor(samples_trajs).float().to(self.device),
|
||||||
"s e h d -> (s e) h d",
|
"s e h d -> (s e) h d",
|
||||||
@ -257,7 +286,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = obs_k[inds_b]
|
obs_b = {"state": obs_k["state"][inds_b]}
|
||||||
samples_b = samples_k[inds_b]
|
samples_b = samples_k[inds_b]
|
||||||
returns_b = returns_k[inds_b]
|
returns_b = returns_k[inds_b]
|
||||||
values_b = values_k[inds_b]
|
values_b = values_k[inds_b]
|
||||||
@ -318,7 +347,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot state trajectories
|
# Plot state trajectories (only in D3IL)
|
||||||
if (
|
if (
|
||||||
self.itr % self.render_freq == 0
|
self.itr % self.render_freq == 0
|
||||||
and self.n_render > 0
|
and self.n_render > 0
|
||||||
@ -354,6 +383,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
run_results[-1]["chains_trajs"] = chains_trajs
|
run_results[-1]["chains_trajs"] = chains_trajs
|
||||||
run_results[-1]["reward_trajs"] = reward_trajs
|
run_results[-1]["reward_trajs"] = reward_trajs
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -374,7 +404,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -399,7 +429,7 @@ class TrainPPOExactDiffusionAgent(TrainPPODiffusionAgent):
|
|||||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||||
run_results[-1]["explained_variance"] = explained_var
|
run_results[-1]["explained_variance"] = explained_var
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -10,6 +10,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import wandb
|
import wandb
|
||||||
|
import math
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
from util.timer import Timer
|
from util.timer import Timer
|
||||||
@ -55,7 +56,9 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Holder
|
# Holder
|
||||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
obs_trajs = {
|
||||||
|
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||||
|
}
|
||||||
samples_trajs = np.empty(
|
samples_trajs = np.empty(
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
@ -67,8 +70,8 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
obs_full_trajs = np.empty((0, self.n_envs, self.obs_dim))
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, prev_obs_venv[None].squeeze(2))
|
(obs_full_trajs, prev_obs_venv["state"][:, -1][None])
|
||||||
) # remove cond_step dim
|
) # save current obs
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
for step in range(self.n_steps):
|
for step in range(self.n_steps):
|
||||||
@ -77,26 +80,33 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = self.model(
|
samples = self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv).float().to(self.device),
|
cond=cond,
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
output_venv = samples.cpu().numpy()
|
output_venv = samples.cpu().numpy()
|
||||||
action_venv = output_venv[:, : self.act_steps, : self.action_dim]
|
action_venv = output_venv[:, : self.act_steps]
|
||||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
|
||||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
|
||||||
|
|
||||||
# Apply multi-step action
|
# Apply multi-step action
|
||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
if self.save_full_observations:
|
if self.save_full_observations: # state-only
|
||||||
obs_full_venv = np.vstack(
|
obs_full_venv = np.array(
|
||||||
[info["full_obs"][None] for info in info_venv]
|
[info["full_obs"]["state"] for info in info_venv]
|
||||||
) # n_envs x n_act_steps x obs_dim
|
) # n_envs x act_steps x obs_dim
|
||||||
obs_full_trajs = np.vstack(
|
obs_full_trajs = np.vstack(
|
||||||
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
(obs_full_trajs, obs_full_venv.transpose(1, 0, 2))
|
||||||
)
|
)
|
||||||
|
obs_trajs["state"] = np.vstack(
|
||||||
|
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||||
|
)
|
||||||
|
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
dones_trajs = np.vstack((dones_trajs, done_venv[None]))
|
||||||
firsts_trajs[step + 1] = done_venv
|
firsts_trajs[step + 1] = done_venv
|
||||||
@ -144,15 +154,25 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Calculate value and logprobs - split into batches to prevent out of memory
|
obs_trajs["state"] = (
|
||||||
obs_t = einops.rearrange(
|
torch.from_numpy(obs_trajs["state"]).float().to(self.device)
|
||||||
torch.from_numpy(obs_trajs).float().to(self.device),
|
|
||||||
"s e h d -> (s e) h d",
|
|
||||||
)
|
)
|
||||||
obs_ts = torch.split(obs_t, self.logprob_batch_size, dim=0)
|
|
||||||
|
# Calculate value and logprobs - split into batches to prevent out of memory
|
||||||
|
num_split = math.ceil(
|
||||||
|
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||||
|
)
|
||||||
|
obs_ts = [{} for _ in range(num_split)]
|
||||||
|
obs_k = einops.rearrange(
|
||||||
|
obs_trajs["state"],
|
||||||
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
obs_ts_k = torch.split(obs_k, self.logprob_batch_size, dim=0)
|
||||||
|
for i, obs_t in enumerate(obs_ts_k):
|
||||||
|
obs_ts[i]["state"] = obs_t
|
||||||
values_trajs = np.empty((0, self.n_envs))
|
values_trajs = np.empty((0, self.n_envs))
|
||||||
for obs in obs_ts:
|
for obs in obs_ts:
|
||||||
values = self.model.critic(obs).cpu().numpy().flatten()
|
values = self.model.critic(obs).cpu().numpy().flatten()
|
||||||
@ -184,7 +204,11 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
reward_trajs = reward_trajs_transpose.T
|
reward_trajs = reward_trajs_transpose.T
|
||||||
|
|
||||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||||
obs_venv_ts = torch.from_numpy(obs_venv).float().to(self.device)
|
obs_venv_ts = {
|
||||||
|
"state": torch.from_numpy(obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_value = (
|
next_value = (
|
||||||
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
self.model.critic(obs_venv_ts).reshape(1, -1).cpu().numpy()
|
||||||
@ -215,10 +239,12 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
returns_trajs = advantages_trajs + values_trajs
|
returns_trajs = advantages_trajs + values_trajs
|
||||||
|
|
||||||
# k for environment step
|
# k for environment step
|
||||||
obs_k = einops.rearrange(
|
obs_k = {
|
||||||
torch.tensor(obs_trajs).float().to(self.device),
|
"state": einops.rearrange(
|
||||||
"s e h d -> (s e) h d",
|
obs_trajs["state"],
|
||||||
)
|
"s e ... -> (s e) ...",
|
||||||
|
)
|
||||||
|
}
|
||||||
samples_k = einops.rearrange(
|
samples_k = einops.rearrange(
|
||||||
torch.tensor(samples_trajs).float().to(self.device),
|
torch.tensor(samples_trajs).float().to(self.device),
|
||||||
"s e h d -> (s e) h d",
|
"s e h d -> (s e) h d",
|
||||||
@ -250,7 +276,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = obs_k[inds_b]
|
obs_b = {"state": obs_k["state"][inds_b]}
|
||||||
samples_b = samples_k[inds_b]
|
samples_b = samples_k[inds_b]
|
||||||
returns_b = returns_k[inds_b]
|
returns_b = returns_k[inds_b]
|
||||||
values_b = values_k[inds_b]
|
values_b = values_k[inds_b]
|
||||||
@ -313,7 +339,7 @@ class TrainPPOGaussianAgent(TrainPPOAgent):
|
|||||||
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot state trajectories
|
# Plot state trajectories (only in D3IL)
|
||||||
if (
|
if (
|
||||||
self.itr % self.render_freq == 0
|
self.itr % self.render_freq == 0
|
||||||
and self.n_render > 0
|
and self.n_render > 0
|
||||||
|
@ -94,8 +94,8 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
key: torch.from_numpy(prev_obs_venv[key])
|
key: torch.from_numpy(prev_obs_venv[key])
|
||||||
.float()
|
.float()
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
for key in self.obs_dims.keys()
|
for key in self.obs_dims
|
||||||
} # batch each type of obs and put into dict
|
}
|
||||||
samples = self.model(
|
samples = self.model(
|
||||||
cond=cond,
|
cond=cond,
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
@ -107,7 +107,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
for k in obs_trajs.keys():
|
for k in obs_trajs:
|
||||||
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
obs_trajs[k] = np.vstack((obs_trajs[k], prev_obs_venv[k][None]))
|
||||||
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
samples_trajs = np.vstack((samples_trajs, output_venv[None]))
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
@ -152,7 +152,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# apply image randomization
|
# apply image randomization
|
||||||
@ -180,7 +180,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
self.n_envs * self.n_steps / self.logprob_batch_size
|
self.n_envs * self.n_steps / self.logprob_batch_size
|
||||||
)
|
)
|
||||||
obs_ts = [{} for _ in range(num_split)]
|
obs_ts = [{} for _ in range(num_split)]
|
||||||
for k in obs_trajs.keys():
|
for k in obs_trajs:
|
||||||
obs_k = einops.rearrange(
|
obs_k = einops.rearrange(
|
||||||
obs_trajs[k],
|
obs_trajs[k],
|
||||||
"s e ... -> (s e) ...",
|
"s e ... -> (s e) ...",
|
||||||
@ -226,7 +226,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
# bootstrap value with GAE if not done - apply reward scaling with constant if specified
|
||||||
obs_venv_ts = {
|
obs_venv_ts = {
|
||||||
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
key: torch.from_numpy(obs_venv[key]).float().to(self.device)
|
||||||
for key in self.obs_dims.keys()
|
for key in self.obs_dims
|
||||||
}
|
}
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_value = (
|
next_value = (
|
||||||
@ -266,7 +266,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
obs_trajs[k],
|
obs_trajs[k],
|
||||||
"s e ... -> (s e) ...",
|
"s e ... -> (s e) ...",
|
||||||
)
|
)
|
||||||
for k in obs_trajs.keys()
|
for k in obs_trajs
|
||||||
}
|
}
|
||||||
samples_k = einops.rearrange(
|
samples_k = einops.rearrange(
|
||||||
torch.tensor(samples_trajs).float().to(self.device),
|
torch.tensor(samples_trajs).float().to(self.device),
|
||||||
@ -297,7 +297,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = {k: obs_k[k][inds_b] for k in obs_k.keys()}
|
obs_b = {k: obs_k[k][inds_b] for k in obs_k}
|
||||||
samples_b = samples_k[inds_b]
|
samples_b = samples_k[inds_b]
|
||||||
returns_b = returns_k[inds_b]
|
returns_b = returns_k[inds_b]
|
||||||
values_b = values_k[inds_b]
|
values_b = values_k[inds_b]
|
||||||
@ -383,6 +383,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -403,7 +404,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | pg loss {pg_loss:8.4f} | value loss {v_loss:8.4f} | bc loss {bc_loss:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -437,7 +438,7 @@ class TrainPPOImgGaussianAgent(TrainPPOGaussianAgent):
|
|||||||
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
run_results[-1]["clip_frac"] = np.mean(clipfracs)
|
||||||
run_results[-1]["explained_variance"] = explained_var
|
run_results[-1]["explained_variance"] = explained_var
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
QSM (Q-Score Matching) for diffusion policy.
|
QSM (Q-Score Matching) for diffusion policy.
|
||||||
|
|
||||||
|
Does not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -75,8 +77,8 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# make a FIFO replay buffer for obs, action, and reward
|
# make a FIFO replay buffer for obs, action, and reward
|
||||||
obs_buffer = deque(maxlen=self.buffer_size)
|
obs_buffer = deque(maxlen=self.buffer_size)
|
||||||
action_buffer = deque(maxlen=self.buffer_size)
|
|
||||||
next_obs_buffer = deque(maxlen=self.buffer_size)
|
next_obs_buffer = deque(maxlen=self.buffer_size)
|
||||||
|
action_buffer = deque(maxlen=self.buffer_size)
|
||||||
reward_buffer = deque(maxlen=self.buffer_size)
|
reward_buffer = deque(maxlen=self.buffer_size)
|
||||||
done_buffer = deque(maxlen=self.buffer_size)
|
done_buffer = deque(maxlen=self.buffer_size)
|
||||||
first_buffer = deque(maxlen=self.buffer_size)
|
first_buffer = deque(maxlen=self.buffer_size)
|
||||||
@ -84,6 +86,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -98,9 +101,10 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -108,7 +112,6 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
@ -118,11 +121,14 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
.cpu()
|
.cpu()
|
||||||
@ -137,9 +143,9 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
|
|
||||||
# add to buffer
|
# add to buffer
|
||||||
obs_buffer.append(prev_obs_venv)
|
obs_buffer.append(prev_obs_venv["state"])
|
||||||
|
next_obs_buffer.append(obs_venv["state"])
|
||||||
action_buffer.append(action_venv)
|
action_buffer.append(action_venv)
|
||||||
next_obs_buffer.append(obs_venv)
|
|
||||||
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
reward_buffer.append(reward_venv * self.scale_reward_factor)
|
||||||
done_buffer.append(done_venv)
|
done_buffer.append(done_venv)
|
||||||
first_buffer.append(firsts_trajs[step])
|
first_buffer.append(firsts_trajs[step])
|
||||||
@ -184,7 +190,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
|
||||||
obs_trajs = np.array(deepcopy(obs_buffer))
|
obs_trajs = np.array(deepcopy(obs_buffer))
|
||||||
@ -233,7 +239,12 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# update critic q function
|
# update critic q function
|
||||||
critic_loss = self.model.loss_critic(
|
critic_loss = self.model.loss_critic(
|
||||||
obs_b, next_obs_b, actions_b, reward_b, done_b, self.gamma
|
{"state": obs_b},
|
||||||
|
{"state": next_obs_b},
|
||||||
|
actions_b,
|
||||||
|
reward_b,
|
||||||
|
done_b,
|
||||||
|
self.gamma,
|
||||||
)
|
)
|
||||||
self.critic_optimizer.zero_grad()
|
self.critic_optimizer.zero_grad()
|
||||||
critic_loss.backward()
|
critic_loss.backward()
|
||||||
@ -246,7 +257,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Update policy with collected trajectories
|
# Update policy with collected trajectories
|
||||||
loss = self.model.loss_actor(
|
loss = self.model.loss_actor(
|
||||||
obs_b,
|
{"state": obs_b},
|
||||||
actions_b,
|
actions_b,
|
||||||
self.q_grad_coeff,
|
self.q_grad_coeff,
|
||||||
)
|
)
|
||||||
@ -274,6 +285,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -294,7 +306,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -310,7 +322,7 @@ class TrainQSMDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["loss_critic"] = loss_critic
|
run_results[-1]["loss_critic"] = loss_critic
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Reward-weighted regression (RWR) for diffusion policy.
|
Reward-weighted regression (RWR) for diffusion policy.
|
||||||
|
|
||||||
|
Does not support pixel input right now.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -54,6 +56,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
# Start training loop
|
# Start training loop
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
run_results = []
|
run_results = []
|
||||||
|
last_itr_eval = False
|
||||||
done_venv = np.zeros((1, self.n_envs))
|
done_venv = np.zeros((1, self.n_envs))
|
||||||
while self.itr < self.n_train_itr:
|
while self.itr < self.n_train_itr:
|
||||||
|
|
||||||
@ -68,9 +71,10 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
# Define train or eval - all envs restart
|
# Define train or eval - all envs restart
|
||||||
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
eval_mode = self.itr % self.val_freq == 0 and not self.force_train
|
||||||
self.model.eval() if eval_mode else self.model.train()
|
self.model.eval() if eval_mode else self.model.train()
|
||||||
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
last_itr_eval = eval_mode
|
||||||
|
|
||||||
# Reset env at the beginning of an iteration
|
# Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) right after eval mode
|
||||||
|
firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
|
||||||
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
if self.reset_at_iteration or eval_mode or last_itr_eval:
|
||||||
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
prev_obs_venv = self.reset_env_all(options_venv=options_venv)
|
||||||
firsts_trajs[0] = 1
|
firsts_trajs[0] = 1
|
||||||
@ -78,11 +82,11 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
firsts_trajs[0] = (
|
firsts_trajs[0] = (
|
||||||
done_venv # if done at the end of last iteration, then the envs are just reset
|
done_venv # if done at the end of last iteration, then the envs are just reset
|
||||||
)
|
)
|
||||||
last_itr_eval = eval_mode
|
|
||||||
reward_trajs = np.empty((0, self.n_envs))
|
|
||||||
|
|
||||||
# Holders
|
# Holder
|
||||||
obs_trajs = np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
obs_trajs = {
|
||||||
|
"state": np.empty((0, self.n_envs, self.n_cond_step, self.obs_dim))
|
||||||
|
}
|
||||||
samples_trajs = np.empty(
|
samples_trajs = np.empty(
|
||||||
(
|
(
|
||||||
0,
|
0,
|
||||||
@ -91,6 +95,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
self.action_dim,
|
self.action_dim,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
reward_trajs = np.empty((0, self.n_envs))
|
||||||
|
|
||||||
# Collect a set of trajectories from env
|
# Collect a set of trajectories from env
|
||||||
for step in range(self.n_steps):
|
for step in range(self.n_steps):
|
||||||
@ -99,24 +104,29 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
cond = {
|
||||||
|
"state": torch.from_numpy(prev_obs_venv["state"])
|
||||||
|
.float()
|
||||||
|
.to(self.device)
|
||||||
|
}
|
||||||
samples = (
|
samples = (
|
||||||
self.model(
|
self.model(
|
||||||
cond=torch.from_numpy(prev_obs_venv)
|
cond=cond,
|
||||||
.float()
|
|
||||||
.to(self.device),
|
|
||||||
deterministic=eval_mode,
|
deterministic=eval_mode,
|
||||||
)
|
)
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy()
|
.numpy()
|
||||||
) # n_env x horizon x act
|
) # n_env x horizon x act
|
||||||
action_venv = samples[:, : self.act_steps]
|
action_venv = samples[:, : self.act_steps]
|
||||||
obs_trajs = np.vstack((obs_trajs, prev_obs_venv[None]))
|
|
||||||
samples_trajs = np.vstack((samples_trajs, samples[None]))
|
samples_trajs = np.vstack((samples_trajs, samples[None]))
|
||||||
|
|
||||||
# Apply multi-step action
|
# Apply multi-step action
|
||||||
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
|
||||||
action_venv
|
action_venv
|
||||||
)
|
)
|
||||||
|
obs_trajs["state"] = np.vstack(
|
||||||
|
(obs_trajs["state"], prev_obs_venv["state"][None])
|
||||||
|
)
|
||||||
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
|
||||||
firsts_trajs[step + 1] = done_venv
|
firsts_trajs[step + 1] = done_venv
|
||||||
prev_obs_venv = obs_venv
|
prev_obs_venv = obs_venv
|
||||||
@ -133,7 +143,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
if len(episodes_start_end) > 0:
|
if len(episodes_start_end) > 0:
|
||||||
# Compute transitions for completed trajectories
|
# Compute transitions for completed trajectories
|
||||||
obs_trajs_split = [
|
obs_trajs_split = [
|
||||||
obs_trajs[start : end + 1, env_ind]
|
{"state": obs_trajs["state"][start : end + 1, env_ind]}
|
||||||
for env_ind, start, end in episodes_start_end
|
for env_ind, start, end in episodes_start_end
|
||||||
]
|
]
|
||||||
samples_trajs_split = [
|
samples_trajs_split = [
|
||||||
@ -183,17 +193,20 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
success_rate = 0
|
success_rate = 0
|
||||||
log.info("[WARNING] No episode completed within the iteration!")
|
log.info("[WARNING] No episode completed within the iteration!")
|
||||||
|
|
||||||
# Update
|
# Update models
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
|
|
||||||
# Tensorize data and put them to device
|
# Tensorize data and put them to device
|
||||||
# k for environment step
|
# k for environment step
|
||||||
obs_k = (
|
obs_k = {
|
||||||
torch.tensor(np.concatenate(obs_trajs_split))
|
"state": torch.tensor(
|
||||||
|
np.concatenate(
|
||||||
|
[obs_traj["state"] for obs_traj in obs_trajs_split]
|
||||||
|
)
|
||||||
|
)
|
||||||
.float()
|
.float()
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
)
|
}
|
||||||
|
|
||||||
samples_k = (
|
samples_k = (
|
||||||
torch.tensor(np.concatenate(samples_trajs_split))
|
torch.tensor(np.concatenate(samples_trajs_split))
|
||||||
.float()
|
.float()
|
||||||
@ -204,19 +217,15 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
returns_trajs_split = (
|
returns_trajs_split = (
|
||||||
returns_trajs_split - np.mean(returns_trajs_split)
|
returns_trajs_split - np.mean(returns_trajs_split)
|
||||||
) / (returns_trajs_split.std() + 1e-3)
|
) / (returns_trajs_split.std() + 1e-3)
|
||||||
|
|
||||||
rewards_k = (
|
rewards_k = (
|
||||||
torch.tensor(returns_trajs_split)
|
torch.tensor(returns_trajs_split)
|
||||||
.float()
|
.float()
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
.reshape(-1)
|
.reshape(-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
rewards_k_scaled = torch.exp(self.beta * rewards_k)
|
rewards_k_scaled = torch.exp(self.beta * rewards_k)
|
||||||
rewards_k_scaled.clamp_(max=self.max_reward_weight)
|
rewards_k_scaled.clamp_(max=self.max_reward_weight)
|
||||||
|
|
||||||
# rewards_k_scaled = rewards_k_scaled / rewards_k_scaled.mean()
|
|
||||||
|
|
||||||
# Update policy and critic
|
# Update policy and critic
|
||||||
total_steps = len(rewards_k_scaled)
|
total_steps = len(rewards_k_scaled)
|
||||||
inds_k = np.arange(total_steps)
|
inds_k = np.arange(total_steps)
|
||||||
@ -229,7 +238,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
start = batch * self.batch_size
|
start = batch * self.batch_size
|
||||||
end = start + self.batch_size
|
end = start + self.batch_size
|
||||||
inds_b = inds_k[start:end] # b for batch
|
inds_b = inds_k[start:end] # b for batch
|
||||||
obs_b = obs_k[inds_b]
|
obs_b = {"state": obs_k["state"][inds_b]}
|
||||||
samples_b = samples_k[inds_b]
|
samples_b = samples_k[inds_b]
|
||||||
rewards_b = rewards_k_scaled[inds_b]
|
rewards_b = rewards_k_scaled[inds_b]
|
||||||
|
|
||||||
@ -261,6 +270,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
if self.itr % self.log_freq == 0:
|
if self.itr % self.log_freq == 0:
|
||||||
|
time = timer()
|
||||||
if eval_mode:
|
if eval_mode:
|
||||||
log.info(
|
log.info(
|
||||||
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
|
||||||
@ -281,7 +291,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
run_results[-1]["eval_best_reward"] = avg_best_reward
|
run_results[-1]["eval_best_reward"] = avg_best_reward
|
||||||
else:
|
else:
|
||||||
log.info(
|
log.info(
|
||||||
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{timer():8.4f}"
|
f"{self.itr}: loss {loss:8.4f} | reward {avg_episode_reward:8.4f} |t:{time:8.4f}"
|
||||||
)
|
)
|
||||||
if self.use_wandb:
|
if self.use_wandb:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@ -295,7 +305,7 @@ class TrainRWRDiffusionAgent(TrainAgent):
|
|||||||
)
|
)
|
||||||
run_results[-1]["loss"] = loss
|
run_results[-1]["loss"] = loss
|
||||||
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
run_results[-1]["train_episode_reward"] = avg_episode_reward
|
||||||
run_results[-1]["time"] = timer()
|
run_results[-1]["time"] = time
|
||||||
with open(self.result_path, "wb") as f:
|
with open(self.result_path, "wb") as f:
|
||||||
pickle.dump(run_results, f)
|
pickle.dump(run_results, f)
|
||||||
self.itr += 1
|
self.itr += 1
|
||||||
|
@ -105,15 +105,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -105,15 +105,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -105,15 +105,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -99,7 +99,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -100,7 +100,6 @@ model:
|
|||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -54,9 +54,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -47,7 +47,7 @@ model:
|
|||||||
residual_style: False
|
residual_style: False
|
||||||
fixed_std: 0.1
|
fixed_std: 0.1
|
||||||
num_modes: ${num_modes}
|
num_modes: ${num_modes}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
|
@ -54,9 +54,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -47,7 +47,7 @@ model:
|
|||||||
residual_style: False
|
residual_style: False
|
||||||
fixed_std: 0.1
|
fixed_std: 0.1
|
||||||
num_modes: ${num_modes}
|
num_modes: ${num_modes}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
|
@ -54,9 +54,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -47,7 +47,7 @@ model:
|
|||||||
residual_style: False
|
residual_style: False
|
||||||
fixed_std: 0.1
|
fixed_std: 0.1
|
||||||
num_modes: ${num_modes}
|
num_modes: ${num_modes}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -109,15 +109,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -100,10 +100,9 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -107,15 +107,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,15 +95,14 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -58,9 +58,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -57,9 +57,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -58,9 +58,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -57,9 +57,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -57,9 +57,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -56,9 +56,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -57,9 +57,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -83,20 +83,19 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -90,13 +90,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -93,19 +93,18 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic_v:
|
critic_v:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -99,20 +99,18 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -82,6 +82,5 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -83,20 +83,19 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -90,13 +90,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -93,19 +93,18 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic_v:
|
critic_v:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -98,20 +98,18 @@ model:
|
|||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -82,6 +82,5 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -83,20 +83,19 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -81,7 +81,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -90,13 +90,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -84,7 +84,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -93,19 +93,18 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic_v:
|
critic_v:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -91,20 +91,18 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -82,7 +82,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -91,13 +91,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -74,7 +74,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -82,6 +82,5 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -87,22 +87,20 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -80,15 +80,14 @@ model:
|
|||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
residual_style: True
|
residual_style: True
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -45,7 +45,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -55,9 +55,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -45,7 +45,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -55,9 +55,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -45,7 +45,7 @@ model:
|
|||||||
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
_target_: model.diffusion.mlp_diffusion.DiffusionMLP
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
cond_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
time_dim: 16
|
time_dim: 16
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
activation_type: ReLU
|
activation_type: ReLU
|
||||||
@ -55,9 +55,7 @@ model:
|
|||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
||||||
|
|
||||||
ema:
|
ema:
|
||||||
|
@ -94,13 +94,12 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -95,13 +95,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -94,13 +94,12 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -97,19 +97,18 @@ model:
|
|||||||
_target_: model.common.critic.CriticObsAct
|
_target_: model.common.critic.CriticObsAct
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
action_steps: ${act_steps}
|
action_steps: ${act_steps}
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
critic_v:
|
critic_v:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -100,15 +100,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -20,6 +20,7 @@ transition_dim: ${action_dim}
|
|||||||
denoising_steps: 100
|
denoising_steps: 100
|
||||||
ft_denoising_steps: 5
|
ft_denoising_steps: 5
|
||||||
cond_steps: 1
|
cond_steps: 1
|
||||||
|
img_cond_steps: 1
|
||||||
horizon_steps: 4
|
horizon_steps: 4
|
||||||
act_steps: 4
|
act_steps: 4
|
||||||
use_ddim: True
|
use_ddim: True
|
||||||
@ -121,6 +122,7 @@ model:
|
|||||||
backbone:
|
backbone:
|
||||||
_target_: model.common.vit.VitEncoder
|
_target_: model.common.vit.VitEncoder
|
||||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||||
|
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||||
cfg:
|
cfg:
|
||||||
patch_size: 8
|
patch_size: 8
|
||||||
depth: 1
|
depth: 1
|
||||||
@ -133,6 +135,7 @@ model:
|
|||||||
time_dim: 32
|
time_dim: 32
|
||||||
mlp_dims: [512, 512, 512]
|
mlp_dims: [512, 512, 512]
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
img_cond_steps: ${img_cond_steps}
|
||||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
@ -143,6 +146,7 @@ model:
|
|||||||
backbone:
|
backbone:
|
||||||
_target_: model.common.vit.VitEncoder
|
_target_: model.common.vit.VitEncoder
|
||||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||||
|
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||||
cfg:
|
cfg:
|
||||||
patch_size: 8
|
patch_size: 8
|
||||||
depth: 1
|
depth: 1
|
||||||
@ -150,15 +154,14 @@ model:
|
|||||||
num_heads: 4
|
num_heads: 4
|
||||||
embed_style: embed2
|
embed_style: embed2
|
||||||
embed_norm: 0
|
embed_norm: 0
|
||||||
obs_dim: ${obs_dim}
|
img_cond_steps: ${img_cond_steps}
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -103,15 +103,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -108,15 +108,13 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
ft_denoising_steps: ${ft_denoising_steps}
|
ft_denoising_steps: ${ft_denoising_steps}
|
||||||
transition_dim: ${transition_dim}
|
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
obs_dim: ${obs_dim}
|
obs_dim: ${obs_dim}
|
||||||
action_dim: ${action_dim}
|
action_dim: ${action_dim}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
denoising_steps: ${denoising_steps}
|
denoising_steps: ${denoising_steps}
|
||||||
device: ${device}
|
device: ${device}
|
@ -94,10 +94,9 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -18,6 +18,7 @@ obs_dim: 9
|
|||||||
action_dim: 7
|
action_dim: 7
|
||||||
transition_dim: ${action_dim}
|
transition_dim: ${action_dim}
|
||||||
cond_steps: 1
|
cond_steps: 1
|
||||||
|
img_cond_steps: 1
|
||||||
horizon_steps: 4
|
horizon_steps: 4
|
||||||
act_steps: 4
|
act_steps: 4
|
||||||
|
|
||||||
@ -100,6 +101,7 @@ model:
|
|||||||
backbone:
|
backbone:
|
||||||
_target_: model.common.vit.VitEncoder
|
_target_: model.common.vit.VitEncoder
|
||||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||||
|
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||||
cfg:
|
cfg:
|
||||||
patch_size: 8
|
patch_size: 8
|
||||||
depth: 1
|
depth: 1
|
||||||
@ -115,6 +117,7 @@ model:
|
|||||||
learn_fixed_std: True
|
learn_fixed_std: True
|
||||||
std_min: 0.01
|
std_min: 0.01
|
||||||
std_max: 0.2
|
std_max: 0.2
|
||||||
|
img_cond_steps: ${img_cond_steps}
|
||||||
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
@ -125,6 +128,7 @@ model:
|
|||||||
backbone:
|
backbone:
|
||||||
_target_: model.common.vit.VitEncoder
|
_target_: model.common.vit.VitEncoder
|
||||||
obs_shape: ${shape_meta.obs.rgb.shape}
|
obs_shape: ${shape_meta.obs.rgb.shape}
|
||||||
|
num_channel: ${eval:'${shape_meta.obs.rgb.shape[0]} * ${img_cond_steps}'} # each image patch is history concatenated
|
||||||
cfg:
|
cfg:
|
||||||
patch_size: 8
|
patch_size: 8
|
||||||
depth: 1
|
depth: 1
|
||||||
@ -132,10 +136,10 @@ model:
|
|||||||
num_heads: 4
|
num_heads: 4
|
||||||
embed_style: embed2
|
embed_style: embed2
|
||||||
embed_norm: 0
|
embed_norm: 0
|
||||||
obs_dim: ${obs_dim}
|
img_cond_steps: ${img_cond_steps}
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
@ -26,7 +26,7 @@ env:
|
|||||||
name: ${env_name}
|
name: ${env_name}
|
||||||
best_reward_threshold_for_success: 1
|
best_reward_threshold_for_success: 1
|
||||||
max_episode_steps: 300
|
max_episode_steps: 300
|
||||||
save_video: false
|
save_video: False
|
||||||
wrappers:
|
wrappers:
|
||||||
robomimic_lowdim:
|
robomimic_lowdim:
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
@ -95,10 +95,9 @@ model:
|
|||||||
transition_dim: ${transition_dim}
|
transition_dim: ${transition_dim}
|
||||||
critic:
|
critic:
|
||||||
_target_: model.common.critic.CriticObs
|
_target_: model.common.critic.CriticObs
|
||||||
obs_dim: ${obs_dim}
|
|
||||||
mlp_dims: [256, 256, 256]
|
mlp_dims: [256, 256, 256]
|
||||||
activation_type: Mish
|
activation_type: Mish
|
||||||
residual_style: True
|
residual_style: True
|
||||||
|
cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
|
||||||
horizon_steps: ${horizon_steps}
|
horizon_steps: ${horizon_steps}
|
||||||
cond_steps: ${cond_steps}
|
|
||||||
device: ${device}
|
device: ${device}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user