add evaluation agents and some example configs

This commit is contained in:
allenzren 2024-09-17 16:32:45 -04:00
parent bc52beca1e
commit c9f24ba0c3
13 changed files with 1120 additions and 14 deletions

View File

@ -83,7 +83,6 @@ python script/train.py --config-name=pre_diffusion_mlp \
See [here](cfg/pretraining.md) for details of the experiments in the paper.
## Usage - Fine-tuning
<!-- ### Set up pre-trained policy -->
@ -139,6 +138,17 @@ See [here](cfg/finetuning.md) for details of the experiments in the paper.
* D3IL environment can be visualized in GUI by `+env.render=True`, `env.n_envs=1`, and `train.render.num=1`. There is a basic script at `script/test_d3il_render.py`.
* Videos of trials in Robomimic tasks can be recorded by specifying `env.save_video=True`, `train.render.freq=<iterations>`, and `train.render.num=<num_video>` in fine-tuning configs.
## Usage - Evaluation
Pre-trained or fine-tuned policies can now be evaluated without running the fine-tuning script. Some example configs are provided under `cfg/{gym/robomimic/furniture}/eval`, including the ones below. Set `base_policy_path` to override the default checkpoint.
```console
python script/train.py --config-name=eval_diffusion_mlp \
--config-dir=cfg/gym/eval/hopper-v2
python script/train.py --config-name=eval_{diffusion/gaussian}_mlp_{?img} \
--config-dir=cfg/robomimic/eval/can
python script/train.py --config-name=eval_diffusion_mlp \
--config-dir=cfg/furniture/eval/one_leg_low
```
## DPPO implementation
Our diffusion implementation is mostly based on [Diffuser](https://github.com/jannerm/diffuser) and at [`model/diffusion/diffusion.py`](model/diffusion/diffusion.py) and [`model/diffusion/diffusion_vpg.py`](model/diffusion/diffusion_vpg.py). PPO specifics are implemented at [`model/diffusion/diffusion_ppo.py`](model/diffusion/diffusion_ppo.py). The main training script is at [`agent/finetune/train_ppo_diffusion_agent.py`](agent/finetune/train_ppo_diffusion_agent.py) that follows [CleanRL](https://github.com/vwxyzjn/cleanrl).

116
agent/eval/eval_agent.py Normal file
View File

@ -0,0 +1,116 @@
"""
Parent eval agent class.
"""
import os
import numpy as np
import torch
import hydra
import logging
import random
log = logging.getLogger(__name__)
from env.gym_utils import make_async
class EvalAgent:
    """Parent class for evaluation agents.

    Seeds RNGs, builds the vectorized environment, instantiates the policy
    model from the Hydra config (loading the checkpoint via the model's
    ``network_path``), and sets up logging/rendering paths. Subclasses
    implement ``run()``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.device = cfg.device
        self.seed = cfg.get("seed", 42)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        # Make vectorized env
        self.env_name = cfg.env.name
        env_type = cfg.env.get("env_type", None)
        self.venv = make_async(
            cfg.env.name,
            env_type=env_type,
            num_envs=cfg.env.n_envs,
            asynchronous=True,
            max_episode_steps=cfg.env.max_episode_steps,
            wrappers=cfg.env.get("wrappers", None),
            robomimic_env_cfg_path=cfg.get("robomimic_env_cfg_path", None),
            shape_meta=cfg.get("shape_meta", None),
            use_image_obs=cfg.env.get("use_image_obs", False),
            render=cfg.env.get("render", False),
            render_offscreen=cfg.env.get("save_video", False),
            obs_dim=cfg.obs_dim,
            action_dim=cfg.action_dim,
            **cfg.env.specific if "specific" in cfg.env else {},
        )
        if env_type != "furniture":  # isaacgym environments do not need seeding
            self.venv.seed(
                [self.seed + i for i in range(cfg.env.n_envs)]
            )  # otherwise parallel envs might have the same initial states!
        self.n_envs = cfg.env.n_envs
        self.n_cond_step = cfg.cond_steps
        self.obs_dim = cfg.obs_dim
        self.action_dim = cfg.action_dim
        self.act_steps = cfg.act_steps
        self.horizon_steps = cfg.horizon_steps
        self.max_episode_steps = cfg.env.max_episode_steps
        self.reset_at_iteration = cfg.env.get("reset_at_iteration", True)
        # Furniture-specific: reward occurs in a single env step, which changes
        # how subclasses compute the best per-episode reward.
        self.furniture_sparse_reward = (
            cfg.env.specific.get("sparse_reward", False)
            if "specific" in cfg.env
            else False
        )

        # Build model and load checkpoint
        self.model = hydra.utils.instantiate(cfg.model)

        # Eval params
        self.n_steps = cfg.n_steps
        self.best_reward_threshold_for_success = (
            len(self.venv.pairs_to_assemble)
            if env_type == "furniture"
            else cfg.env.best_reward_threshold_for_success
        )

        # Logging, rendering
        # Evaluation runs a single iteration; subclasses use `itr` in video
        # file names, so define it here to avoid an AttributeError.
        self.itr = 0
        self.logdir = cfg.logdir
        self.render_dir = os.path.join(self.logdir, "render")
        self.result_path = os.path.join(self.logdir, "result.npz")
        os.makedirs(self.render_dir, exist_ok=True)
        self.n_render = cfg.render_num
        self.render_video = cfg.env.get("save_video", False)
        assert self.n_render <= self.n_envs, "n_render must be <= n_envs"
        assert not (
            self.n_render <= 0 and self.render_video
        ), "Need to set n_render > 0 if saving video"

    def run(self):
        """Evaluation loop; implemented by subclasses."""
        pass

    def reset_env_all(self, verbose=False, options_venv=None, **kwargs):
        """Reset all parallel envs and return batched observations.

        Args:
            verbose: log each env's reset options.
            options_venv: per-env option dicts; defaults to ``kwargs`` copied
                per env.

        Returns:
            dict of stacked observation arrays (keyed by obs modality).
        """
        if options_venv is None:
            options_venv = [
                {k: v for k, v in kwargs.items()} for _ in range(self.n_envs)
            ]
        obs_venv = self.venv.reset_arg(options_list=options_venv)
        # Convert to a dict of stacked arrays if the venv returns per-env dicts
        if isinstance(obs_venv, list):
            obs_venv = {
                key: np.stack([obs_venv[i][key] for i in range(self.n_envs)])
                for key in obs_venv[0].keys()
            }
        if verbose:
            for index in range(self.n_envs):
                log.info(
                    f"<-- Reset environment {index} with options {options_venv[index]}"
                )
        return obs_venv

    def reset_env(self, env_ind, verbose=False):
        """Reset a single env and return its observation."""
        task = {}
        obs = self.venv.reset_one_arg(env_ind=env_ind, options=task)
        if verbose:
            log.info(f"<-- Reset environment {env_ind} with task {task}")
        return obs

View File

@ -0,0 +1,119 @@
"""
Evaluate pre-trained/DPPO-fine-tuned diffusion policy.
"""
import os
import numpy as np
import torch
import logging
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent
class EvalDiffusionAgent(EvalAgent):
    """Evaluate a pre-trained/DPPO-fine-tuned diffusion policy.

    Rolls out the policy in the vectorized env for ``cfg.n_steps`` multi-steps,
    aggregates per-episode rewards, and saves metrics to ``result.npz``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """Run one evaluation iteration and save metrics to ``self.result_path``."""
        timer = Timer()

        # Prepare video paths for each env --- only applies to the first set of
        # episodes if reset is allowed within the iteration and each iteration
        # has multiple episodes from one env.
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            # The base class may not define `itr`; default to 0 so video
            # naming does not crash when save_video=True.
            itr = getattr(self, "itr", 0)
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"itr-{itr}_trial-{env_ind}.mp4"
                )

        # Reset env before iteration starts
        self.model.eval()
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.empty((0, self.n_envs))

        # Collect a set of trajectories from env
        for step in range(self.n_steps):
            if step % 10 == 0:
                log.info(f"Processed step {step} of {self.n_steps}")

            # Sample actions from the diffusion policy
            with torch.no_grad():
                cond = {
                    "state": torch.from_numpy(prev_obs_venv["state"])
                    .float()
                    .to(self.device)
                }
                samples = self.model(cond=cond)
                output_venv = (
                    samples.trajectories.cpu().numpy()
                )  # n_env x horizon x act
            action_venv = output_venv[:, : self.act_steps]

            # Apply multi-step action
            obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
            reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
            firsts_trajs[step + 1] = done_venv
            prev_obs_venv = obs_venv

        # Summarize episode reward --- only count episodes that finish within
        # the iteration; consecutive "first" flags bound one episode.
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for i in range(len(env_steps) - 1):
                start = env_steps[i]
                end = env_steps[i + 1]
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))
        if len(episodes_start_end) > 0:
            reward_trajs_split = [
                reward_trajs[start : end + 1, env_ind]
                for env_ind, start, end in episodes_start_end
            ]
            num_episode_finished = len(reward_trajs_split)
            episode_reward = np.array(
                [np.sum(reward_traj) for reward_traj in reward_trajs_split]
            )
            if self.furniture_sparse_reward:
                # Furniture tasks: reward occurs in one env step, so the
                # episode total is already the best reward.
                episode_best_reward = episode_reward
            else:
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
            avg_episode_reward = np.mean(episode_reward)
            avg_best_reward = np.mean(episode_best_reward)
            success_rate = np.mean(
                episode_best_reward >= self.best_reward_threshold_for_success
            )
        else:
            episode_reward = np.array([])
            num_episode_finished = 0
            avg_episode_reward = 0
            avg_best_reward = 0
            success_rate = 0
            log.info("[WARNING] No episode completed within the iteration!")

        # Log and save metrics
        time = timer()
        log.info(
            f"eval: num episode {num_episode_finished:4d} | success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=time,
        )

View File

@ -0,0 +1,122 @@
"""
Evaluate pre-trained/DPPO-fine-tuned pixel-based diffusion policy.
"""
import os
import numpy as np
import torch
import logging
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent
class EvalImgDiffusionAgent(EvalAgent):
    """Evaluate a pre-trained/DPPO-fine-tuned pixel-based diffusion policy.

    Same rollout/metric logic as the state-based agent, but conditions the
    policy on a dict of observation modalities (e.g. rgb + state).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Set obs dims --- the different obs modalities are batched in a dict
        shape_meta = cfg.shape_meta
        self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs}

    def run(self):
        """Run one evaluation iteration and save metrics to ``self.result_path``."""
        timer = Timer()

        # Prepare video paths for each env --- only applies to the first set of
        # episodes if reset is allowed within the iteration and each iteration
        # has multiple episodes from one env.
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            # The base class may not define `itr`; default to 0 so video
            # naming does not crash when save_video=True.
            itr = getattr(self, "itr", 0)
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"itr-{itr}_trial-{env_ind}.mp4"
                )

        # Reset env before iteration starts
        self.model.eval()
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.empty((0, self.n_envs))

        # Collect a set of trajectories from env
        for step in range(self.n_steps):
            if step % 10 == 0:
                log.info(f"Processed step {step} of {self.n_steps}")

            # Sample actions from the diffusion policy
            with torch.no_grad():
                cond = {
                    key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
                    for key in self.obs_dims
                }  # batch each type of obs and put into dict
                samples = self.model(cond=cond)
                output_venv = (
                    samples.trajectories.cpu().numpy()
                )  # n_env x horizon x act
            action_venv = output_venv[:, : self.act_steps]

            # Apply multi-step action
            obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
            reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
            firsts_trajs[step + 1] = done_venv
            prev_obs_venv = obs_venv

        # Summarize episode reward --- only count episodes that finish within
        # the iteration; consecutive "first" flags bound one episode.
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for i in range(len(env_steps) - 1):
                start = env_steps[i]
                end = env_steps[i + 1]
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))
        if len(episodes_start_end) > 0:
            reward_trajs_split = [
                reward_trajs[start : end + 1, env_ind]
                for env_ind, start, end in episodes_start_end
            ]
            num_episode_finished = len(reward_trajs_split)
            episode_reward = np.array(
                [np.sum(reward_traj) for reward_traj in reward_trajs_split]
            )
            if self.furniture_sparse_reward:
                # Furniture tasks: reward occurs in one env step, so the
                # episode total is already the best reward.
                episode_best_reward = episode_reward
            else:
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
            avg_episode_reward = np.mean(episode_reward)
            avg_best_reward = np.mean(episode_best_reward)
            success_rate = np.mean(
                episode_best_reward >= self.best_reward_threshold_for_success
            )
        else:
            episode_reward = np.array([])
            num_episode_finished = 0
            avg_episode_reward = 0
            avg_best_reward = 0
            success_rate = 0
            log.info("[WARNING] No episode completed within the iteration!")

        # Log and save metrics
        time = timer()
        log.info(
            f"eval: num episode {num_episode_finished:4d} | success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=time,
        )

View File

@ -0,0 +1,117 @@
"""
Evaluate pre-trained/fine-tuned Gaussian/GMM policy.
"""
import os
import numpy as np
import torch
import logging
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent
class EvalGaussianAgent(EvalAgent):
    """Evaluate a pre-trained/fine-tuned Gaussian/GMM policy.

    Same rollout/metric logic as the diffusion agent, but samples actions
    deterministically from the Gaussian policy (the model returns the action
    tensor directly rather than a Sample namedtuple).
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def run(self):
        """Run one evaluation iteration and save metrics to ``self.result_path``."""
        timer = Timer()

        # Prepare video paths for each env --- only applies to the first set of
        # episodes if reset is allowed within the iteration and each iteration
        # has multiple episodes from one env.
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            # The base class may not define `itr`; default to 0 so video
            # naming does not crash when save_video=True.
            itr = getattr(self, "itr", 0)
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"itr-{itr}_trial-{env_ind}.mp4"
                )

        # Reset env before iteration starts
        self.model.eval()
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.empty((0, self.n_envs))

        # Collect a set of trajectories from env
        for step in range(self.n_steps):
            if step % 10 == 0:
                log.info(f"Processed step {step} of {self.n_steps}")

            # Select deterministic action from the Gaussian policy
            with torch.no_grad():
                cond = {
                    "state": torch.from_numpy(prev_obs_venv["state"])
                    .float()
                    .to(self.device)
                }
                samples = self.model(cond=cond, deterministic=True)
                output_venv = samples.cpu().numpy()
            action_venv = output_venv[:, : self.act_steps]

            # Apply multi-step action
            obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
            reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
            firsts_trajs[step + 1] = done_venv
            prev_obs_venv = obs_venv

        # Summarize episode reward --- only count episodes that finish within
        # the iteration; consecutive "first" flags bound one episode.
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for i in range(len(env_steps) - 1):
                start = env_steps[i]
                end = env_steps[i + 1]
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))
        if len(episodes_start_end) > 0:
            reward_trajs_split = [
                reward_trajs[start : end + 1, env_ind]
                for env_ind, start, end in episodes_start_end
            ]
            num_episode_finished = len(reward_trajs_split)
            episode_reward = np.array(
                [np.sum(reward_traj) for reward_traj in reward_trajs_split]
            )
            if self.furniture_sparse_reward:
                # Furniture tasks: reward occurs in one env step, so the
                # episode total is already the best reward.
                episode_best_reward = episode_reward
            else:
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
            avg_episode_reward = np.mean(episode_reward)
            avg_best_reward = np.mean(episode_best_reward)
            success_rate = np.mean(
                episode_best_reward >= self.best_reward_threshold_for_success
            )
        else:
            episode_reward = np.array([])
            num_episode_finished = 0
            avg_episode_reward = 0
            avg_best_reward = 0
            success_rate = 0
            log.info("[WARNING] No episode completed within the iteration!")

        # Log and save metrics
        time = timer()
        log.info(
            f"eval: num episode {num_episode_finished:4d} | success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=time,
        )

View File

@ -0,0 +1,120 @@
"""
Evaluate pre-trained/fine-tuned Gaussian/GMM pixel-based policy.
"""
import os
import numpy as np
import torch
import logging
log = logging.getLogger(__name__)
from util.timer import Timer
from agent.eval.eval_agent import EvalAgent
class EvalImgGaussianAgent(EvalAgent):
    """Evaluate a pre-trained/fine-tuned Gaussian/GMM pixel-based policy.

    Same rollout/metric logic as the state-based Gaussian agent, but
    conditions the policy on a dict of observation modalities (e.g. rgb + state).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Set obs dims --- the different obs modalities are batched in a dict
        shape_meta = cfg.shape_meta
        self.obs_dims = {k: shape_meta.obs[k]["shape"] for k in shape_meta.obs.keys()}

    def run(self):
        """Run one evaluation iteration and save metrics to ``self.result_path``."""
        timer = Timer()

        # Prepare video paths for each env --- only applies to the first set of
        # episodes if reset is allowed within the iteration and each iteration
        # has multiple episodes from one env.
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            # The base class may not define `itr`; default to 0 so video
            # naming does not crash when save_video=True.
            itr = getattr(self, "itr", 0)
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"itr-{itr}_trial-{env_ind}.mp4"
                )

        # Reset env before iteration starts
        self.model.eval()
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.empty((0, self.n_envs))

        # Collect a set of trajectories from env
        for step in range(self.n_steps):
            if step % 10 == 0:
                log.info(f"Processed step {step} of {self.n_steps}")

            # Select deterministic action from the Gaussian policy
            with torch.no_grad():
                cond = {
                    key: torch.from_numpy(prev_obs_venv[key]).float().to(self.device)
                    for key in self.obs_dims
                }  # batch each type of obs and put into dict
                samples = self.model(cond=cond, deterministic=True)
                output_venv = samples.cpu().numpy()
            action_venv = output_venv[:, : self.act_steps]

            # Apply multi-step action
            obs_venv, reward_venv, done_venv, info_venv = self.venv.step(action_venv)
            reward_trajs = np.vstack((reward_trajs, reward_venv[None]))
            firsts_trajs[step + 1] = done_venv
            prev_obs_venv = obs_venv

        # Summarize episode reward --- only count episodes that finish within
        # the iteration; consecutive "first" flags bound one episode.
        episodes_start_end = []
        for env_ind in range(self.n_envs):
            env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
            for i in range(len(env_steps) - 1):
                start = env_steps[i]
                end = env_steps[i + 1]
                if end - start > 1:
                    episodes_start_end.append((env_ind, start, end - 1))
        if len(episodes_start_end) > 0:
            reward_trajs_split = [
                reward_trajs[start : end + 1, env_ind]
                for env_ind, start, end in episodes_start_end
            ]
            num_episode_finished = len(reward_trajs_split)
            episode_reward = np.array(
                [np.sum(reward_traj) for reward_traj in reward_trajs_split]
            )
            if self.furniture_sparse_reward:
                # Furniture tasks: reward occurs in one env step, so the
                # episode total is already the best reward.
                episode_best_reward = episode_reward
            else:
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
            avg_episode_reward = np.mean(episode_reward)
            avg_best_reward = np.mean(episode_best_reward)
            success_rate = np.mean(
                episode_best_reward >= self.best_reward_threshold_for_success
            )
        else:
            episode_reward = np.array([])
            num_episode_finished = 0
            avg_episode_reward = 0
            avg_best_reward = 0
            success_rate = 0
            log.info("[WARNING] No episode completed within the iteration!")

        # Log and save metrics
        time = timer()
        log.info(
            f"eval: num episode {num_episode_finished:4d} | success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
        )
        np.savez(
            self.result_path,
            num_episode=num_episode_finished,
            eval_success_rate=success_rate,
            eval_episode_reward=avg_episode_reward,
            eval_best_reward=avg_best_reward,
            time=time,
        )

View File

@ -0,0 +1,67 @@
# Evaluation config: pre-trained diffusion MLP policy on FurnitureBench one_leg (low randomness).
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_diffusion_agent.EvalDiffusionAgent

name: ${env_name}_eval_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/furniture-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/furniture-pretrain/one_leg/one_leg_low_dim_pre_diffusion_mlp_ta8_td100/2024-07-22_20-01-16/checkpoint/state_8000.pt
normalization_path: ${oc.env:DPPO_DATA_DIR}/furniture/${env.specific.furniture}_${env.specific.randomness}/normalization.pth

seed: 42
device: cuda:0
env_name: ${env.specific.furniture}_${env.specific.randomness}_dim
obs_dim: 58
action_dim: 10
transition_dim: ${action_dim}
denoising_steps: 100
cond_steps: 1
horizon_steps: 8
act_steps: 8
use_ddim: True
ddim_steps: 5

n_steps: ${eval:'round(${env.max_episode_steps} / ${act_steps})'}
render_num: 0

env:
  n_envs: 1000
  name: ${env_name}
  env_type: furniture
  max_episode_steps: 700
  best_reward_threshold_for_success: 1
  specific:
    headless: true
    furniture: one_leg
    randomness: low
    normalization_path: ${normalization_path}
    obs_steps: ${cond_steps}
    act_steps: ${act_steps}
    sparse_reward: True

model:
  _target_: model.diffusion.diffusion.DiffusionModel
  predict_epsilon: True
  denoised_clip_value: 1.0
  randn_clip_value: 3
  use_ddim: ${use_ddim}
  ddim_steps: ${ddim_steps}
  network_path: ${base_policy_path}
  network:
    _target_: model.diffusion.mlp_diffusion.DiffusionMLP
    time_dim: 32
    mlp_dims: [1024, 1024, 1024, 1024, 1024, 1024, 1024]
    cond_mlp_dims: [512, 64]
    use_layernorm: True # needed for larger MLP
    residual_style: True
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  obs_dim: ${obs_dim}
  action_dim: ${action_dim}
  denoising_steps: ${denoising_steps}
  device: ${device}

View File

@ -0,0 +1,62 @@
# Evaluation config: pre-trained diffusion MLP policy on D4RL hopper-medium-v2.
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_diffusion_agent.EvalDiffusionAgent

name: ${env_name}_eval_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/gym-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/gym-pretrain/hopper-medium-v2_pre_diffusion_mlp_ta4_td20/2024-06-12_23-10-05/checkpoint/state_3000.pt
normalization_path: ${oc.env:DPPO_DATA_DIR}/gym/${env_name}/normalization.npz

seed: 42
device: cuda:0
env_name: hopper-medium-v2
obs_dim: 11
action_dim: 3
transition_dim: ${action_dim}
denoising_steps: 20
cond_steps: 1
horizon_steps: 4
act_steps: 4

n_steps: 500 # each episode can take maximum (max_episode_steps / act_steps, =250 right now) steps but may finish earlier in gym. We only count episodes finished within n_steps for evaluation.
render_num: 0

env:
  n_envs: 40
  name: ${env_name}
  max_episode_steps: 1000
  reset_at_iteration: False
  save_video: False
  best_reward_threshold_for_success: 3 # success rate not relevant for gym tasks
  wrappers:
    mujoco_locomotion_lowdim:
      normalization_path: ${normalization_path}
    multi_step:
      n_obs_steps: ${cond_steps}
      n_action_steps: ${act_steps}
      max_episode_steps: ${env.max_episode_steps}
      reset_within_step: True

model:
  _target_: model.diffusion.diffusion.DiffusionModel
  predict_epsilon: True
  denoised_clip_value: 1.0
  network_path: ${base_policy_path}
  network:
    _target_: model.diffusion.mlp_diffusion.DiffusionMLP
    time_dim: 16
    mlp_dims: [512, 512, 512]
    activation_type: ReLU
    residual_style: True
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  obs_dim: ${obs_dim}
  action_dim: ${action_dim}
  denoising_steps: ${denoising_steps}
  device: ${device}

View File

@ -0,0 +1,66 @@
# Evaluation config: pre-trained diffusion MLP policy on Robomimic can (low-dim obs).
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_diffusion_agent.EvalDiffusionAgent

name: ${env_name}_eval_diffusion_mlp_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/can/can_pre_diffusion_mlp_ta4_td20/2024-06-28_13-29-54/checkpoint/state_5000.pt # use 8000 for comparing policy parameterizations
robomimic_env_cfg_path: cfg/robomimic/env_meta/${env_name}.json
normalization_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env_name}/normalization.npz

seed: 42
device: cuda:0
env_name: can
obs_dim: 23
action_dim: 7
transition_dim: ${action_dim}
denoising_steps: 20
cond_steps: 1
horizon_steps: 4
act_steps: 4

n_steps: 300 # each episode takes max_episode_steps / act_steps steps
render_num: 0

env:
  n_envs: 50
  name: ${env_name}
  best_reward_threshold_for_success: 1
  max_episode_steps: 300
  save_video: False
  wrappers:
    robomimic_lowdim:
      normalization_path: ${normalization_path}
      low_dim_keys: ['robot0_eef_pos',
                     'robot0_eef_quat',
                     'robot0_gripper_qpos',
                     'object'] # same order of preprocessed observations
    multi_step:
      n_obs_steps: ${cond_steps}
      n_action_steps: ${act_steps}
      max_episode_steps: ${env.max_episode_steps}
      reset_within_step: True

model:
  _target_: model.diffusion.diffusion.DiffusionModel
  predict_epsilon: True
  denoised_clip_value: 1.0
  randn_clip_value: 3
  network_path: ${base_policy_path}
  network:
    _target_: model.diffusion.mlp_diffusion.DiffusionMLP
    time_dim: 16
    mlp_dims: [512, 512, 512]
    residual_style: True
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  obs_dim: ${obs_dim}
  action_dim: ${action_dim}
  denoising_steps: ${denoising_steps}
  device: ${device}

View File

@ -0,0 +1,98 @@
# Evaluation config: pre-trained pixel-based diffusion MLP policy on Robomimic can.
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_diffusion_img_agent.EvalImgDiffusionAgent

name: ${env_name}_eval_diffusion_mlp_img_ta${horizon_steps}_td${denoising_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/can/can_pre_diffusion_mlp_img_ta4_td100/2024-07-30_22-23-55/checkpoint/state_5000.pt
robomimic_env_cfg_path: cfg/robomimic/env_meta/${env_name}-img.json
normalization_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env_name}-img/normalization.npz

seed: 42
device: cuda:0
env_name: can
obs_dim: 9
action_dim: 7
transition_dim: ${action_dim}
denoising_steps: 100
cond_steps: 1
img_cond_steps: 1
horizon_steps: 4
act_steps: 4
use_ddim: True
ddim_steps: 5

n_steps: 300 # each episode takes max_episode_steps / act_steps steps
render_num: 0

env:
  n_envs: 50
  name: ${env_name}
  best_reward_threshold_for_success: 1
  max_episode_steps: 300
  save_video: False
  use_image_obs: True
  wrappers:
    robomimic_image:
      normalization_path: ${normalization_path}
      low_dim_keys: ['robot0_eef_pos',
                     'robot0_eef_quat',
                     'robot0_gripper_qpos']
      image_keys: ['robot0_eye_in_hand_image']
      shape_meta: ${shape_meta}
    multi_step:
      n_obs_steps: ${cond_steps}
      n_action_steps: ${act_steps}
      max_episode_steps: ${env.max_episode_steps}
      reset_within_step: True

shape_meta:
  obs:
    rgb:
      shape: [3, 96, 96]
    state:
      shape: [9]
  action:
    shape: [7]

model:
  _target_: model.diffusion.diffusion.DiffusionModel
  predict_epsilon: True
  denoised_clip_value: 1.0
  randn_clip_value: 3
  use_ddim: ${use_ddim}
  ddim_steps: ${ddim_steps}
  network_path: ${base_policy_path}
  network:
    _target_: model.diffusion.mlp_diffusion.VisionDiffusionMLP
    backbone:
      _target_: model.common.vit.VitEncoder
      obs_shape: ${shape_meta.obs.rgb.shape}
      num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
      img_h: ${shape_meta.obs.rgb.shape[1]}
      img_w: ${shape_meta.obs.rgb.shape[2]}
      cfg:
        patch_size: 8
        depth: 1
        embed_dim: 128
        num_heads: 4
        embed_style: embed2
        embed_norm: 0
    augment: False
    spatial_emb: 128
    time_dim: 32
    mlp_dims: [512, 512, 512]
    residual_style: True
    img_cond_steps: ${img_cond_steps}
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  obs_dim: ${obs_dim}
  action_dim: ${action_dim}
  denoising_steps: ${denoising_steps}
  device: ${device}

View File

@ -0,0 +1,60 @@
# Evaluation config: pre-trained Gaussian MLP policy on Robomimic can (low-dim obs).
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_gaussian_agent.EvalGaussianAgent

name: ${env_name}_eval_gaussian_mlp_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/can/can_pre_gaussian_mlp_ta4/2024-06-28_13-31-00/checkpoint/state_5000.pt
robomimic_env_cfg_path: cfg/robomimic/env_meta/${env_name}.json
normalization_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env_name}/normalization.npz

seed: 42
device: cuda:0
env_name: can
obs_dim: 23
action_dim: 7
transition_dim: ${action_dim}
cond_steps: 1
horizon_steps: 4
act_steps: 4

n_steps: 300 # each episode takes max_episode_steps / act_steps steps
render_num: 0

env:
  n_envs: 50
  name: ${env_name}
  best_reward_threshold_for_success: 1
  max_episode_steps: 300
  save_video: False
  wrappers:
    robomimic_lowdim:
      normalization_path: ${normalization_path}
      low_dim_keys: ['robot0_eef_pos',
                     'robot0_eef_quat',
                     'robot0_gripper_qpos',
                     'object'] # same order of preprocessed observations
    multi_step:
      n_obs_steps: ${cond_steps}
      n_action_steps: ${act_steps}
      max_episode_steps: ${env.max_episode_steps}
      reset_within_step: True

model:
  _target_: model.common.gaussian.GaussianModel
  randn_clip_value: 3
  network_path: ${base_policy_path}
  network:
    _target_: model.common.mlp_gaussian.Gaussian_MLP
    mlp_dims: [512, 512, 512]
    residual_style: True
    fixed_std: 0.1
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  device: ${device}

View File

@ -0,0 +1,87 @@
# Evaluation config: pre-trained pixel-based Gaussian MLP policy on Robomimic can.
# NOTE(review): indentation was lost in the rendered source; nesting below is
# reconstructed from the _target_ signatures and interpolations — confirm against the repo.
defaults:
  - _self_
hydra:
  run:
    dir: ${logdir}
_target_: agent.eval.eval_gaussian_img_agent.EvalImgGaussianAgent

name: ${env_name}_eval_gaussian_mlp_img_ta${horizon_steps}
logdir: ${oc.env:DPPO_LOG_DIR}/robomimic-eval/${name}/${now:%Y-%m-%d}_${now:%H-%M-%S}_${seed}
base_policy_path: ${oc.env:DPPO_LOG_DIR}/robomimic-pretrain/can/can_pre_gaussian_mlp_img_ta4/2024-07-28_21-54-40/checkpoint/state_1000.pt
robomimic_env_cfg_path: cfg/robomimic/env_meta/${env_name}-img.json
normalization_path: ${oc.env:DPPO_DATA_DIR}/robomimic/${env_name}-img/normalization.npz

seed: 42
device: cuda:0
env_name: can
obs_dim: 9
action_dim: 7
transition_dim: ${action_dim}
cond_steps: 1
img_cond_steps: 1
horizon_steps: 4
act_steps: 4

n_steps: 300 # each episode takes max_episode_steps / act_steps steps
render_num: 0

env:
  n_envs: 50
  name: ${env_name}
  best_reward_threshold_for_success: 1
  max_episode_steps: 300
  save_video: False
  use_image_obs: True
  wrappers:
    robomimic_image:
      normalization_path: ${normalization_path}
      low_dim_keys: ['robot0_eef_pos',
                     'robot0_eef_quat',
                     'robot0_gripper_qpos']
      image_keys: ['robot0_eye_in_hand_image']
      shape_meta: ${shape_meta}
    multi_step:
      n_obs_steps: ${cond_steps}
      n_action_steps: ${act_steps}
      max_episode_steps: ${env.max_episode_steps}
      reset_within_step: True

shape_meta:
  obs:
    rgb:
      shape: [3, 96, 96]
    state:
      shape: [9]
  action:
    shape: [7]

model:
  _target_: model.common.gaussian.GaussianModel
  randn_clip_value: 3
  network_path: ${base_policy_path}
  network:
    _target_: model.common.mlp_gaussian.Gaussian_VisionMLP
    backbone:
      _target_: model.common.vit.VitEncoder
      obs_shape: ${shape_meta.obs.rgb.shape}
      num_channel: ${eval:'3 * ${img_cond_steps}'} # each image patch is history concatenated
      img_h: ${shape_meta.obs.rgb.shape[1]}
      img_w: ${shape_meta.obs.rgb.shape[2]}
      cfg:
        patch_size: 8
        depth: 1
        embed_dim: 128
        num_heads: 4
        embed_style: embed2
        embed_norm: 0
    augment: False
    spatial_emb: 128
    mlp_dims: [512, 512, 512]
    residual_style: True
    fixed_std: 0.1
    cond_dim: ${eval:'${obs_dim} * ${cond_steps}'}
    horizon_steps: ${horizon_steps}
    transition_dim: ${transition_dim}
  horizon_steps: ${horizon_steps}
  device: ${device}

View File

@ -8,7 +8,6 @@ Annotated DDIM/DDPM: https://nn.labml.ai/diffusion/stable_diffusion/sampler/ddpm
"""
from typing import Union
import logging
import torch
from torch import nn
@ -19,6 +18,7 @@ log = logging.getLogger(__name__)
from model.diffusion.sampling import (
extract,
cosine_beta_schedule,
make_timesteps,
)
from collections import namedtuple
@ -35,10 +35,14 @@ class DiffusionModel(nn.Module):
action_dim,
network_path=None,
device="cuda:0",
# Various clipping
denoised_clip_value=1.0,
randn_clip_value=10,
final_action_clip_value=None,
eps_clip_value=None, # DDIM only
# DDPM parameters
denoising_steps=100,
predict_epsilon=True,
denoised_clip_value=1.0,
# DDIM sampling
use_ddim=False,
ddim_discretize='uniform',
@ -51,11 +55,22 @@ class DiffusionModel(nn.Module):
self.obs_dim = obs_dim
self.action_dim = action_dim
self.denoising_steps = int(denoising_steps)
self.denoised_clip_value = denoised_clip_value
self.predict_epsilon = predict_epsilon
self.use_ddim = use_ddim
self.ddim_steps = ddim_steps
# Clip noise value at each denoising step
self.denoised_clip_value = denoised_clip_value
# Whether to clamp the final sampled action between [-1, 1]
self.final_action_clip_value = final_action_clip_value
# For each denoising step, we clip sampled randn (from standard deviation) such that the sampled action is not too far away from mean
self.randn_clip_value = randn_clip_value
# Clip epsilon for numerical stability
self.eps_clip_value = eps_clip_value
# Set up models
self.network = network.to(device)
if network_path is not None:
@ -154,7 +169,7 @@ class DiffusionModel(nn.Module):
# ---------- Sampling ----------#
def p_mean_var(self, x, t, cond=None, index=None):
def p_mean_var(self, x, t, cond, index=None):
noise = self.network(x, t, cond=cond)
# Predict x_0
@ -183,12 +198,16 @@ class DiffusionModel(nn.Module):
# re-calculate noise based on clamped x_recon - default to false in HF, but let's use it here
noise = (x - alpha ** (0.5) * x_recon) / sqrt_one_minus_alpha
# Clip epsilon for numerical stability in policy gradient - not sure if this is helpful yet, but the value can be huge sometimes. This has no effect if DDPM is used
if self.use_ddim and self.eps_clip_value is not None:
noise.clamp_(-self.eps_clip_value, self.eps_clip_value)
# Get mu
if self.use_ddim:
"""
μ = αₜ x₀ + (1-αₜ - σₜ²) ε
var should be zero here as self.ddim_eta=0
eta=0
"""
sigma = extract(self.ddim_sigmas, index, x.shape)
dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * noise
@ -209,13 +228,56 @@ class DiffusionModel(nn.Module):
return mu, logvar
@torch.no_grad()
def forward(
self,
cond,
return_chain=True,
**kwargs,
):
raise NotImplementedError
def forward(self, cond):
"""
Forward pass for sampling actions. Used in evaluating pre-trained/fine-tuned policy. Not modifying diffusion clipping
Args:
cond: dict with key state/rgb; more recent obs at the end
state: (B, To, Do)
rgb: (B, To, C, H, W)
Return:
Sample: namedtuple with fields:
trajectories: (B, Ta, Da)
"""
# NOTE(review): unlike the previous forward, this one is not decorated with
# @torch.no_grad(); the eval agents wrap the call in torch.no_grad() — confirm
# no other caller relies on the decorator.
device = self.betas.device
# Batch size is taken from whichever obs modality is present
sample_data = cond["state"] if "state" in cond else cond["rgb"]
B = len(sample_data)
# Loop
# Start from Gaussian noise and denoise step by step
x = torch.randn((B, self.horizon_steps, self.action_dim), device=device)
# DDIM uses the precomputed timestep subsequence; DDPM sweeps all
# denoising steps in reverse
if self.use_ddim:
t_all = self.ddim_t
else:
t_all = list(reversed(range(self.denoising_steps)))
for i, t in enumerate(t_all):
t_b = make_timesteps(B, t, device)
index_b = make_timesteps(B, i, device)
mean, logvar = self.p_mean_var(
x=x,
t=t_b,
cond=cond,
index=index_b,
)
std = torch.exp(0.5 * logvar)
# Determine noise level
if self.use_ddim:
# Deterministic DDIM sampling: no noise added
std = torch.zeros_like(std)
else:
if t == 0:
# No noise at the final DDPM step
std = torch.zeros_like(std)
else:
# Floor the std for numerical stability
std = torch.clip(std, min=1e-3)
# Clip sampled noise so the action stays near the mean
noise = torch.randn_like(x).clamp_(
-self.randn_clip_value, self.randn_clip_value
)
x = mean + std * noise
# clamp action at final step
if self.final_action_clip_value is not None and i == len(t_all) - 1:
x = torch.clamp(x, -self.final_action_clip_value, self.final_action_clip_value)
return Sample(x, None)
# ---------- Supervised training ----------#
@ -229,7 +291,7 @@ class DiffusionModel(nn.Module):
def p_losses(
self,
x_start,
cond: Union[dict, torch.Tensor],
cond: dict,
t,
):
"""