dppo/agent/finetune/train_idql_diffusion_agent.py
Allen Z. Ren e0842e71dc
v0.5 to main (#10)
* v0.5 (#9)

* update idql configs

* update awr configs

* update dipo configs

* update qsm configs

* update dqm configs

* update project version to 0.5.0
2024-10-07 16:35:13 -04:00

368 lines
15 KiB
Python

"""
Implicit diffusion Q-learning (IDQL) trainer for diffusion policy.
Do not support pixel input right now.
"""
import os
import pickle
import einops
import numpy as np
import torch
import logging
import wandb
from copy import deepcopy
log = logging.getLogger(__name__)
from util.timer import Timer
from collections import deque
from agent.finetune.train_agent import TrainAgent
from util.scheduler import CosineAnnealingWarmupRestarts
class TrainIDQLDiffusionAgent(TrainAgent):
    """Fine-tuning trainer for Implicit Diffusion Q-Learning (IDQL).

    Sets up the actor, Q-critic, and V-critic optimizers together with their
    cosine-annealing-with-warmup LR schedules, plus the replay-buffer and
    sampling hyperparameters, all read from ``cfg.train``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Note: the discount factor gamma here is applied to the reward every
        # act_steps, instead of every env step.
        self.gamma = cfg.train.gamma

        # Warm-up period for the critics before actor updates start
        self.n_critic_warmup_itr = cfg.train.n_critic_warmup_itr

        # Actor optimizer + LR schedule
        self.actor_optimizer = torch.optim.AdamW(
            self.model.actor.parameters(),
            lr=cfg.train.actor_lr,
            weight_decay=cfg.train.actor_weight_decay,
        )
        self.actor_lr_scheduler = self._make_lr_scheduler(
            self.actor_optimizer,
            cfg.train.actor_lr,
            cfg.train.actor_lr_scheduler,
        )

        # Q(s, a) and V(s) critics share the same LR / weight-decay settings
        self.critic_q_optimizer = torch.optim.AdamW(
            self.model.critic_q.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        self.critic_v_optimizer = torch.optim.AdamW(
            self.model.critic_v.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        self.critic_v_lr_scheduler = self._make_lr_scheduler(
            self.critic_v_optimizer,
            cfg.train.critic_lr,
            cfg.train.critic_lr_scheduler,
        )
        self.critic_q_lr_scheduler = self._make_lr_scheduler(
            self.critic_q_optimizer,
            cfg.train.critic_lr,
            cfg.train.critic_lr_scheduler,
        )

        # Replay buffer capacity (number of vectorized-env steps kept, FIFO)
        self.buffer_size = cfg.train.buffer_size

        # Actor params: whether to use expectile-based exploration sampling
        self.use_expectile_exploration = cfg.train.use_expectile_exploration

        # Scaling applied to rewards before they enter the replay buffer
        self.scale_reward_factor = cfg.train.scale_reward_factor

        # Updates: gradient-step-to-env-step ratio, and Polyak averaging rate
        # for the target Q network
        self.replay_ratio = cfg.train.replay_ratio
        self.critic_tau = cfg.train.critic_tau

        # Whether to use deterministic mode when sampling at eval
        self.eval_deterministic = cfg.train.get("eval_deterministic", False)

        # Number of action samples drawn per state at evaluation time
        self.num_sample = cfg.train.eval_sample_num

    @staticmethod
    def _make_lr_scheduler(optimizer, max_lr, sched_cfg):
        """Build a single-cycle cosine-annealing-with-warmup LR scheduler.

        Factored out because the actor and both critics used three
        byte-identical constructions differing only in optimizer and LR bounds.
        """
        return CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=sched_cfg.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=max_lr,
            min_lr=sched_cfg.min_lr,
            warmup_steps=sched_cfg.warmup_steps,
            gamma=1.0,
        )
    def run(self):
        """Main training loop: alternate env rollouts with IDQL updates.

        Each iteration collects ``n_steps`` of vectorized-env experience into
        FIFO replay buffers, then (in train mode) performs ``num_batch``
        gradient updates on the V critic, Q critic, and actor (actor steps are
        gated by the critic warm-up). Periodically evaluates, saves the model,
        and logs metrics to wandb and a pickle results file.
        """
        # make a FIFO replay buffer for obs, action, and reward; each entry is
        # one vectorized env step (leading dim n_envs)
        obs_buffer = deque(maxlen=self.buffer_size)
        next_obs_buffer = deque(maxlen=self.buffer_size)
        action_buffer = deque(maxlen=self.buffer_size)
        reward_buffer = deque(maxlen=self.buffer_size)
        terminated_buffer = deque(maxlen=self.buffer_size)

        # Start training loop
        timer = Timer()
        run_results = []
        cnt_train_step = 0  # cumulative env steps taken in train mode
        last_itr_eval = False
        done_venv = np.zeros((1, self.n_envs))
        while self.itr < self.n_train_itr:
            # Prepare video paths for each envs --- only applies for the first
            # set of episodes if allowing reset within iteration and each
            # iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = self.itr % self.val_freq == 0 and not self.force_train
            self.model.eval() if eval_mode else self.model.train()
            # NOTE(review): last_itr_eval is overwritten *before* the reset
            # check below, so the "(3) right after eval mode" branch can never
            # differ from eval_mode itself — likely intended to be assigned
            # after the check; confirm against other trainers in this repo.
            last_itr_eval = eval_mode

            # Reset env before iteration starts (1) if specified, (2) at eval
            # mode, or (3) right after eval mode
            firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
            if self.reset_at_iteration or eval_mode or last_itr_eval:
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs[0] = 1
            else:
                # if done at the end of last iteration, the envs are just reset
                firsts_trajs[0] = done_venv
            reward_trajs = np.zeros((self.n_steps, self.n_envs))

            # Collect a set of trajectories from env
            for step in range(self.n_steps):
                # NOTE(review): progress print — consider log.debug instead of
                # print for consistency with the module logger.
                if step % 10 == 0:
                    print(f"Processed step {step} of {self.n_steps}")

                # Select action (no grad: pure inference during rollout)
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = (
                        self.model(
                            cond=cond,
                            deterministic=eval_mode and self.eval_deterministic,
                            num_sample=self.num_sample,
                            use_expectile_exploration=self.use_expectile_exploration,
                        )
                        .cpu()
                        .numpy()
                    )  # n_env x horizon x act
                # only execute the first act_steps of the predicted horizon
                action_venv = samples[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                    self.venv.step(action_venv)
                )
                done_venv = terminated_venv | truncated_venv
                reward_trajs[step] = reward_venv
                firsts_trajs[step + 1] = done_venv

                # add to buffer
                if not eval_mode:
                    obs_venv_copy = obs_venv.copy()
                    for i in range(self.n_envs):
                        # on truncation the env auto-resets, so the true next
                        # obs must be recovered from the final-obs info field
                        if truncated_venv[i]:
                            obs_venv_copy["state"][i] = info_venv[i]["final_obs"][
                                "state"
                            ]
                    obs_buffer.append(prev_obs_venv["state"])
                    next_obs_buffer.append(obs_venv_copy["state"])
                    action_buffer.append(action_venv)
                    reward_buffer.append(reward_venv * self.scale_reward_factor)
                    terminated_buffer.append(terminated_venv)

                # update for next step
                prev_obs_venv = obs_venv

                # count steps --- not accounting for done within action chunk
                cnt_train_step += self.n_envs * self.act_steps if not eval_mode else 0

            # Summarize episode reward --- this needs to be handled differently
            # depending on whether the environment is reset after each
            # iteration. Only count episodes that finish within the iteration.
            episodes_start_end = []
            for env_ind in range(self.n_envs):
                env_steps = np.where(firsts_trajs[:, env_ind] == 1)[0]
                for i in range(len(env_steps) - 1):
                    start = env_steps[i]
                    end = env_steps[i + 1]
                    if end - start > 1:
                        episodes_start_end.append((env_ind, start, end - 1))
            if len(episodes_start_end) > 0:
                reward_trajs_split = [
                    reward_trajs[start : end + 1, env_ind]
                    for env_ind, start, end in episodes_start_end
                ]
                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(reward_traj) for reward_traj in reward_trajs_split]
                )
                # best per-step reward within the episode, normalized by the
                # action-chunk length (reward_venv sums over act_steps)
                episode_best_reward = np.array(
                    [
                        np.max(reward_traj) / self.act_steps
                        for reward_traj in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0
                log.info("[WARNING] No episode completed within the iteration!")

            # Update models
            if not eval_mode:
                # gradient steps this iteration, scaled by the replay ratio
                num_batch = int(
                    self.n_steps * self.n_envs / self.batch_size * self.replay_ratio
                )
                # snapshot the full buffers; NOTE(review): this rebinds
                # reward_trajs from the rollout array to the buffer contents —
                # safe only because episode summarization above already ran.
                obs_trajs = np.array(deepcopy(obs_buffer))
                action_trajs = np.array(deepcopy(action_buffer))
                next_obs_trajs = np.array(deepcopy(next_obs_buffer))
                reward_trajs = np.array(deepcopy(reward_buffer))
                terminated_trajs = np.array(deepcopy(terminated_buffer))

                # flatten the (steps, envs) leading dims into one batch dim
                obs_trajs = einops.rearrange(
                    obs_trajs,
                    "s e h d -> (s e) h d",
                )
                next_obs_trajs = einops.rearrange(
                    next_obs_trajs,
                    "s e h d -> (s e) h d",
                )
                action_trajs = einops.rearrange(
                    action_trajs,
                    "s e h d -> (s e) h d",
                )
                reward_trajs = reward_trajs.reshape(-1)
                terminated_trajs = terminated_trajs.reshape(-1)
                for _ in range(num_batch):
                    # Sample batch (with replacement)
                    inds = np.random.choice(len(obs_trajs), self.batch_size)
                    obs_b = torch.from_numpy(obs_trajs[inds]).float().to(self.device)
                    next_obs_b = (
                        torch.from_numpy(next_obs_trajs[inds]).float().to(self.device)
                    )
                    actions_b = (
                        torch.from_numpy(action_trajs[inds]).float().to(self.device)
                    )
                    reward_b = (
                        torch.from_numpy(reward_trajs[inds]).float().to(self.device)
                    )
                    terminated_b = (
                        torch.from_numpy(terminated_trajs[inds]).float().to(self.device)
                    )

                    # update critic value function
                    critic_loss_v = self.model.loss_critic_v(
                        {"state": obs_b}, actions_b
                    )
                    self.critic_v_optimizer.zero_grad()
                    critic_loss_v.backward()
                    self.critic_v_optimizer.step()

                    # update critic q function
                    critic_loss_q = self.model.loss_critic_q(
                        {"state": obs_b},
                        {"state": next_obs_b},
                        actions_b,
                        reward_b,
                        terminated_b,
                        self.gamma,
                    )
                    self.critic_q_optimizer.zero_grad()
                    critic_loss_q.backward()
                    self.critic_q_optimizer.step()

                    # update target q function (Polyak averaging with critic_tau)
                    self.model.update_target_critic(self.critic_tau)
                    loss_critic = critic_loss_q.detach() + critic_loss_v.detach()

                    # Update policy with collected trajectories - no weighting.
                    # Gradients are computed every batch, but the optimizer only
                    # steps once the critic warm-up period has elapsed.
                    loss_actor = self.model.loss(
                        actions_b,
                        {"state": obs_b},
                    )
                    self.actor_optimizer.zero_grad()
                    loss_actor.backward()
                    if self.itr >= self.n_critic_warmup_itr:
                        if self.max_grad_norm is not None:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.actor.parameters(), self.max_grad_norm
                            )
                        self.actor_optimizer.step()

                # Update lr
                self.actor_lr_scheduler.step()
                self.critic_v_lr_scheduler.step()
                self.critic_q_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append(
                {
                    "itr": self.itr,
                    "step": cnt_train_step,
                }
            )
            if self.itr % self.log_freq == 0:
                time = timer()
                run_results[-1]["time"] = time
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        # commit=False: defer flushing until the train-mode log
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    log.info(
                        f"{self.itr}: step {cnt_train_step:8d} | loss actor {loss_actor:8.4f} | reward {avg_episode_reward:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "total env step": cnt_train_step,
                                "loss - actor": loss_actor,
                                "loss - critic": loss_critic,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                # persist the full results history every log interval
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1