# reppo/reppo_alg/torchrl/reppo.py
# Snapshot metadata: 2025-07-21 18:31:20 -04:00 — 736 lines, 26 KiB, Python.
from dataclasses import dataclass, replace
import functools
import os
import random
import sys
import copy
import time
import numpy as np
import tqdm
from omegaconf import DictConfig, OmegaConf
import wandb
from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss
try:
# Required for avoiding IsaacGym import error
import isaacgym
except ImportError:
pass
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchinfo import summary
from tensordict import TensorDict
from torch.amp import GradScaler
from reppo_alg.torchrl.envs import make_envs
from reppo_alg.network_utils.torch_models import Actor, Critic
# Allow TF32 tensor cores for fp32 matmuls (trades a little precision for speed).
torch.set_float32_matmul_precision("medium")
# Let torch.compile trace into built-in nn.Module methods.
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
# Avoid CPU thread oversubscription alongside the vectorized envs.
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
    # Headless EGL rendering for MuJoCo on Linux; macOS needs a windowed GL context.
    os.environ["MUJOCO_GL"] = "egl"
else:
    os.environ["MUJOCO_GL"] = "glfw"
# JAX settings (some env backends embed JAX): don't preallocate all GPU memory,
# and keep matmuls at full precision.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"
@dataclass(slots=True)
class TrainState:
    """Mutable bundle of everything the training loop threads through:
    current observations, networks, normalizers, optimizers and the AMP scaler."""

    device: torch.device
    obs: torch.Tensor
    critic_obs: torch.Tensor
    actor: Actor
    old_actor: Actor  # frozen KL-reference copy of the actor
    critic: Critic
    normalizer: EmpiricalNormalization
    critic_normalizer: EmpiricalNormalization
    actor_optimizer: optim.Optimizer
    critic_optimizer: optim.Optimizer
    scaler: GradScaler

    def compile(self):
        """torch.compile every network/normalizer held by this state."""
        modules = (
            self.actor,
            self.old_actor,
            self.critic,
            self.normalizer,
            self.critic_normalizer,
        )
        for module in modules:
            module.compile()
def get_autocast_context(cfg: DictConfig):
    """Return a zero-argument factory for ``torch.amp.autocast`` configured
    from the platform section of *cfg*.

    Autocast is only enabled when both the config requests CUDA and CUDA is
    actually available; the device type falls back to MPS and then CPU.
    """
    cuda_ok = cfg.platform.cuda and torch.cuda.is_available()
    if cuda_ok:
        device_type = "cuda"
    elif cfg.platform.cuda and torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    # Only bf16 is supported as a reduced precision; anything else runs fp32.
    if cfg.platform.amp_dtype == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    return functools.partial(
        torch.amp.autocast,
        device_type=device_type,
        dtype=dtype,
        enabled=cfg.platform.amp_enabled and cuda_ok,
    )
def make_collect_fn(cfg: DictConfig, env):
    """Build a rollout closure over *env* for the current policy.

    The returned function maps a TrainState to
    ``(updated TrainState, time-stacked TensorDict of transitions, per-step infos)``.
    """
    autocast = get_autocast_context(cfg)
    # Whether the critic receives privileged observations distinct from the actor's.
    asymmetric_obs = env.asymmetric_obs

    def collect_fn(
        train_state: TrainState,
    ) -> tuple[TrainState, TensorDict, list[dict]]:
        transitions = []
        info_list = []
        obs = train_state.obs
        critic_obs = train_state.critic_obs
        for _ in range(cfg.hyperparameters.num_steps):
            with autocast():
                # Normalizers are in train mode here, so these calls also
                # update their running statistics.
                norm_obs = train_state.normalizer(obs)
                norm_critic_obs = train_state.critic_normalizer(critic_obs)
                with torch.inference_mode():
                    pi, _, _, _ = train_state.actor(norm_obs)
                    actions = pi.sample()
            next_obs, rewards, dones, truncations, infos = env.step(actions)
            if asymmetric_obs:
                next_critic_obs = infos["observations"]["critic"]
            else:
                next_critic_obs = next_obs
            with (
                torch.inference_mode(),
                autocast(),
            ):
                if (
                    cfg.env.get("has_final_obs", False)
                    and cfg.env.get("partial_reset", False)
                    and "final_observation" in infos
                ):
                    # Auto-resetting envs report the pre-reset observation here;
                    # bootstrap from it rather than the post-reset one.
                    _next_obs = infos["final_observation"]
                    _next_critic_obs = _next_obs
                else:
                    _next_obs = next_obs
                    _next_critic_obs = next_critic_obs
                norm_next_obs = train_state.normalizer(_next_obs)
                next_pi, _, temperature, _ = train_state.actor(norm_next_obs)
                next_actions = next_pi.sample()
                next_log_probs = next_pi.log_prob(
                    next_actions.clip(-1 + 1e-6, 1 - 1e-6)
                ).sum(-1)
                norm_next_critic_obs = train_state.critic_normalizer(_next_critic_obs)
                next_value, _, _, next_embedding = train_state.critic(
                    norm_next_critic_obs, next_actions
                )
                # Entropy-regularized (soft) reward: subtract the discounted,
                # temperature-weighted next-step log-probability.
                rewards = (
                    rewards - cfg.hyperparameters.gamma * next_log_probs * temperature
                )
            transitions.append(
                TensorDict(
                    {
                        "observations": norm_obs,
                        "critic_observations": norm_critic_obs,
                        "actions": actions,
                        # NOTE(review): the clip range here (-0.999, 0.999)
                        # differs from the (-1 + 1e-6, 1 - 1e-6) used above —
                        # confirm the asymmetry is intentional.
                        "log_probs": pi.log_prob(actions.clip(-0.999, 0.999)).sum(-1),
                        "rewards": rewards.unsqueeze(-1),
                        "next_embeddings": next_embedding,
                        "next_values": next_value.unsqueeze(-1),
                        "dones": dones.unsqueeze(-1).float(),
                        "truncations": truncations.unsqueeze(-1).float(),
                    },
                    batch_size=(env.num_envs,),
                )
            )
            info_list.append(infos)
            obs = next_obs
            critic_obs = next_critic_obs
        # Carry the final observations over to the next collection phase.
        train_state = replace(train_state, obs=obs, critic_obs=critic_obs)
        return (
            train_state,
            torch.stack(transitions, dim=0),
            info_list,
        )

    return collect_fn
def make_postprocess_fn(cfg: DictConfig, env):
    """Build the post-rollout processing step.

    The returned function turns a time-major rollout TensorDict into a flat
    ``num_steps * num_envs`` training batch with lambda-style value targets
    ("gve") attached.  *env* is accepted for signature symmetry with the other
    factories but is not used.
    """

    # Kept out of the compiled graph: the Python loop over time steps would
    # otherwise be unrolled/retraced by torch.compile.
    @torch.compiler.disable()
    def compute_gve(rewards, dones, truncated, next_values, device: torch.device):
        """Backward recursion for the generalized value estimates.

        target_t = r_t + gamma * next_value_t when step t is truncated,
        otherwise r_t + gamma * (1 - done_t) * (lmbda * target_{t+1}
        + (1 - lmbda) * next_value_t).  ``device`` is unused.
        """
        gves = []
        last_gve = 0
        # Force a bootstrap at the last collected step.
        # NOTE: mutates the caller's truncation tensor in place.
        truncated[-1] = 1.0
        for t in reversed(range(cfg.hyperparameters.num_steps)):
            lambda_sum = (
                cfg.hyperparameters.lmbda * last_gve
                + (1.0 - cfg.hyperparameters.lmbda) * next_values[t]
            )
            delta = cfg.hyperparameters.gamma * torch.where(
                truncated[t].bool(), next_values[t], (1.0 - dones[t]) * lambda_sum
            )
            last_gve = rewards[t] + delta
            # Append (O(1)) and reverse once at the end; the original
            # list.insert(0, ...) made this loop quadratic in num_steps.
            gves.append(last_gve)
        gves.reverse()
        return gves

    def postprocess(train_state: TrainState, transition: TensorDict):
        """Attach value targets and flatten the rollout into a training batch."""
        gve = compute_gve(
            rewards=transition["rewards"],
            dones=transition["dones"],
            truncated=transition["truncations"],
            next_values=transition["next_values"],
            device=train_state.device,
        )
        # Flatten all time and environment dimensions into a single batch dimension
        data = TensorDict(
            {
                "observations": transition["observations"],
                "critic_observations": transition["critic_observations"],
                "actions": transition["actions"],
                "rewards": transition["rewards"],
                "next_embeddings": transition["next_embeddings"],
                "next_values": transition["next_values"],
                "dones": transition["dones"],
                "truncations": transition["truncations"],
                "gve": torch.stack(gve),
            },
            batch_size=(
                cfg.hyperparameters.num_steps,
                cfg.hyperparameters.num_envs,
            ),
            device=train_state.device,
        )
        return data.float().flatten(0, 1).detach()

    return postprocess
def make_critic_update_fn(cfg: DictConfig, train_state: TrainState):
    """Build the per-minibatch update step for the distributional critic."""
    autocast = get_autocast_context(cfg)

    def update(data: TensorDict):
        qnet = train_state.critic
        q_optimizer = train_state.critic_optimizer
        with autocast():
            critic_observations = data["critic_observations"]
            actions = data["actions"]
            targets = data["gve"]
            target_embeddings = data["next_embeddings"]
            truncations = data["truncations"].squeeze(-1)
            if cfg.env.get("partial_reset", False):
                # With partial resets every sample bootstraps correctly, so
                # nothing is masked out.
                # NOTE(review): this branch produces a bool tensor while the
                # other produces floats; both broadcast in the products below,
                # but the dtype mismatch is worth confirming.
                truncation_mask = torch.ones_like(
                    truncations, dtype=torch.bool, device=train_state.device
                )
            else:
                truncation_mask = 1.0 - truncations
            # Project scalar targets onto the categorical support (HL-Gauss).
            qf_target_dist = hl_gauss(
                targets,
                cfg.hyperparameters.vmin,
                cfg.hyperparameters.vmax,
                cfg.hyperparameters.num_bins,
            )
            _, qf1, embedding, _ = qnet(critic_observations, actions)
            # Cross-entropy between the target distribution and predicted logits,
            # masked at truncation boundaries.
            qf_loss = -(
                truncation_mask
                * torch.sum(qf_target_dist * F.log_softmax(qf1, dim=-1), dim=-1)
            ).mean()
            # Auxiliary self-prediction loss: match the next state-action embedding.
            embedding_loss = (
                truncation_mask
                * F.mse_loss(
                    embedding,
                    target_embeddings,
                    reduction="none",
                ).mean(dim=-1)
            ).mean()
            qf_loss = qf_loss + cfg.hyperparameters.aux_loss_mult * embedding_loss
        q_optimizer.zero_grad(set_to_none=True)
        # Standard AMP sequence: scale -> backward -> unscale -> clip -> step -> update.
        train_state.scaler.scale(qf_loss).backward()
        train_state.scaler.unscale_(q_optimizer)
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            qnet.parameters(), max_norm=cfg.hyperparameters.max_grad_norm
        )
        train_state.scaler.step(q_optimizer)
        train_state.scaler.update()
        logs_dict = {
            "critic_grad_norm": critic_grad_norm.detach(),
            "qf_loss": qf_loss.detach(),
            "qf_max": targets.max().detach(),
            "qf_min": targets.min().detach(),
            "qf_mean": targets.mean().detach(),
            "embedding_loss": embedding_loss.detach(),
        }
        return logs_dict

    return update
def make_actor_update_fn(cfg: DictConfig, train_state: TrainState):
    """Build the per-minibatch update for the actor, its entropy temperature,
    and the KL lagrangian multiplier (all optimized jointly in one loss)."""
    autocast = get_autocast_context(cfg)

    def update(data: TensorDict):
        actor = train_state.actor
        old_actor = train_state.old_actor
        qnet = train_state.critic
        actor_optimizer = train_state.actor_optimizer
        scaler = train_state.scaler
        critic_obs = data["critic_observations"]
        with autocast():
            pi, _, temperature, beta = actor(data["observations"])
            # rsample keeps the pathwise gradient through the sampled actions.
            actions = pi.rsample()
            log_probs = pi.log_prob(actions.clip(-1 + 1e-6, 1 - 1e-6)).sum(-1)
            entropy = -log_probs
            qf, _, _, _ = qnet(critic_obs, actions)
            # Soft policy objective: maximize Q minus temperature-weighted log-prob.
            actor_loss = -qf + temperature.detach() * log_probs
            # compute KL
            # Monte-Carlo estimate of KL(old || new) from 16 samples of the
            # frozen reference policy.
            old_pi, _, _, _ = old_actor(data["observations"])
            old_pi_actions = old_pi.sample((16,)).clip(-1 + 1e-6, 1 - 1e-6)
            old_log_probs = old_pi.log_prob(old_pi_actions).sum(-1).mean(0)
            new_pi_log_probs = pi.log_prob(old_pi_actions).sum(-1).mean(0)
            kl = old_log_probs - new_pi_log_probs
            if cfg.hyperparameters.actor_kl_clip_mode == "clipped":
                # Replace the objective by a pure KL penalty wherever the
                # KL bound is violated.
                actor_loss = torch.where(
                    kl < cfg.hyperparameters.kl_bound,
                    actor_loss,
                    kl * beta.detach(),
                ).mean()
            elif cfg.hyperparameters.actor_kl_clip_mode == "full":
                # Always add the KL penalty.
                actor_loss = actor_loss + kl * beta.detach()
            elif cfg.hyperparameters.actor_kl_clip_mode == "value":
                # No KL term in the actor loss (KL still drives beta below).
                actor_loss = actor_loss
            else:
                raise ValueError(
                    f"Unknown actor kl clip mode: {cfg.hyperparameters.actor_kl_clip_mode}"
                )
            # temperature updates
            target_entropy = (
                actions.shape[-1] * cfg.hyperparameters.ent_target_mult
            )  # -0.5 * np.prod(envs.action_space.shape)
            # Dual ascent on the temperature toward the entropy target ...
            entropy_loss = (target_entropy + entropy).detach().mean() * temperature
            # ... and on the lagrangian multiplier beta toward the KL bound.
            lagrangian_loss = (
                -beta * (kl - cfg.hyperparameters.kl_bound).mean().detach()
            )
            actor_loss = (actor_loss + entropy_loss + lagrangian_loss).mean()
        actor_optimizer.zero_grad(set_to_none=True)
        # AMP sequence: scale -> backward -> unscale -> clip -> step -> update.
        scaler.scale(actor_loss).backward()
        scaler.unscale_(actor_optimizer)
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
            actor.parameters(), max_norm=cfg.hyperparameters.max_grad_norm
        )
        scaler.step(actor_optimizer)
        scaler.update()
        logs_dict = {
            "actor_grad_norm": actor_grad_norm.detach(),
            "actor_loss": actor_loss.detach(),
            "kl": kl.detach(),
            "entropy": entropy.detach(),
            "temperature": temperature.detach(),
            "lagrangian": beta.detach(),
            "entropy_loss": entropy_loss.detach(),
            "lagrangian_loss": lagrangian_loss.detach(),
        }
        return logs_dict

    return update
def make_evaluate_fn(cfg: DictConfig, eval_envs):
    """Build an evaluation closure that rolls the policy out on *eval_envs*."""
    autocast = get_autocast_context(cfg)

    @torch.inference_mode()
    def evaluate(
        train_state: TrainState, stochastic_eval: bool = False
    ) -> tuple[float, float, dict]:
        """Run one batched evaluation rollout.

        Returns (mean episode return, mean episode length, extra info dict).
        Returns/lengths are accumulated per env only until its first done.
        """
        # Freeze normalizer statistics while evaluating.
        train_state.normalizer.eval()
        num_eval_envs = eval_envs.num_envs
        episode_returns = torch.zeros(num_eval_envs, device=train_state.device)
        episode_lengths = torch.zeros(num_eval_envs, device=train_state.device)
        done_masks = torch.zeros(
            num_eval_envs, dtype=torch.bool, device=train_state.device
        )
        if cfg.env.type == "isaaclab" or cfg.env.asymmetric_observation:
            obs, _ = eval_envs.reset(random_start_init=False)
        else:
            obs = eval_envs.reset()
        # Run for a fixed number of steps
        for i in range(eval_envs.max_episode_steps):
            with autocast():
                obs = train_state.normalizer(obs)
                action_dist, det_actions, _, _ = train_state.actor(obs)
                if stochastic_eval:
                    actions = action_dist.sample()
                else:
                    # Deterministic (mode) actions by default.
                    actions = det_actions
            next_obs, rewards, dones, _, infos = eval_envs.step(actions)
            # Accumulate only for envs that have not finished yet.
            episode_returns = torch.where(
                ~done_masks, episode_returns + rewards, episode_returns
            )
            episode_lengths = torch.where(
                ~done_masks, episode_lengths + 1, episode_lengths
            )
            done_masks = torch.logical_or(done_masks, dones)
            if done_masks.all():
                break
            obs = next_obs
        train_state.normalizer.train()
        if cfg.env.type == "maniskill":
            # combine log_infos
            # NOTE(review): uses `infos` from the last executed step; assumes
            # the loop ran at least once.
            info = {
                "info_return": infos["log_info"]["return"].mean(),
                "episode_len": infos["log_info"]["episode_len"].float().mean(),
                "success": infos["log_info"]["success"].float().mean(),
                "return": episode_returns.mean().item(),
            }
        else:
            info = {}
        return episode_returns.mean().item(), episode_lengths.mean().item(), info

    return evaluate
def configure_platform(cfg: DictConfig) -> DictConfig:
    """Resolve the AMP settings in *cfg* against the hardware actually present.

    Mutates ``cfg.platform`` in place (AMP only stays enabled when CUDA is
    requested and available; the AMP device falls back CUDA -> MPS -> CPU)
    and returns the same config object.
    """
    cuda_ok = cfg.platform.cuda and torch.cuda.is_available()
    cfg.platform.amp_enabled = cfg.platform.amp_enabled and cuda_ok
    if cuda_ok:
        cfg.platform.amp_device = "cuda"
    elif cfg.platform.cuda and torch.backends.mps.is_available():
        cfg.platform.amp_device = "mps"
    else:
        cfg.platform.amp_device = "cpu"
    return cfg
@hydra.main(
    version_base=None,
    config_path="../../config",
    config_name="reppo",
)
def main(cfg):
    """Entry point: build envs, networks and optimizers, then run the
    collect -> postprocess -> update -> evaluate training loop."""
    cfg = configure_platform(cfg)
    run_name = f"{cfg.env.name}_torch_{cfg.seed}"
    # NOTE(review): cfg.platform.amp_dtype is compared with the string "bf16"
    # elsewhere, so comparing it to torch.float16 here looks like it is always
    # False (GradScaler permanently disabled) — confirm whether the string
    # "fp16" was intended.
    scaler = GradScaler(
        enabled=cfg.platform.amp_enabled and cfg.platform.amp_dtype == torch.float16
    )
    num_batches = cfg.hyperparameters.num_mini_batches
    batch_size = (
        cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps // num_batches
    )
    wandb.init(
        project=cfg.wandb.project,
        name=run_name,
        config=OmegaConf.to_container(cfg),
        save_code=True,
    )
    # Seed all RNGs for reproducibility.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.platform.torch_deterministic
    if not cfg.platform.cuda:
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{cfg.platform.device_rank}")
        elif torch.backends.mps.is_available():
            device = torch.device(f"mps:{cfg.platform.device_rank}")
        else:
            raise ValueError("No GPU available")
    print(f"Using device: {device}")
    envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)
    # Resolve flat observation/action sizes (env APIs report int or shape tuple).
    n_act = envs.num_actions
    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
    if envs.asymmetric_obs:
        n_critic_obs = (
            envs.num_privileged_obs
            if isinstance(envs.num_privileged_obs, int)
            else envs.num_privileged_obs[0]
        )
    else:
        n_critic_obs = n_obs
    if cfg.hyperparameters.normalize_env:
        obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
        critic_obs_normalizer = EmpiricalNormalization(
            shape=n_critic_obs, device=device
        )
    else:
        # Identity modules keep the call sites uniform when normalization is off.
        obs_normalizer = nn.Identity()
        critic_obs_normalizer = nn.Identity()
    actor = Actor(
        n_obs=n_obs,
        n_act=n_act,
        ent_start=cfg.hyperparameters.ent_start,
        kl_start=cfg.hyperparameters.kl_start,
        hidden_dim=cfg.hyperparameters.actor_hidden_dim,
        use_norm=cfg.hyperparameters.use_actor_norm,
        layers=cfg.hyperparameters.num_actor_layers,
        min_std=cfg.hyperparameters.actor_min_std,
        device=device,
    )
    # Frozen KL-reference policy; synced to the actor after each update phase.
    old_actor = copy.deepcopy(actor)
    qnet = Critic(
        n_obs=n_critic_obs,
        n_act=n_act,
        num_atoms=cfg.hyperparameters.num_bins,
        vmin=cfg.hyperparameters.vmin,
        vmax=cfg.hyperparameters.vmax,
        hidden_dim=cfg.hyperparameters.critic_hidden_dim,
        use_norm=cfg.hyperparameters.use_critic_norm,
        use_encoder_norm=False,
        encoder_layers=cfg.hyperparameters.num_critic_encoder_layers,
        head_layers=cfg.hyperparameters.num_critic_head_layers,
        pred_layers=cfg.hyperparameters.num_critic_pred_layers,
        device=device,
    )
    # LR passed as a tensor (keeps it on-device for compiled optimizer steps).
    q_optimizer = optim.AdamW(
        list(qnet.parameters()),
        lr=torch.tensor(cfg.hyperparameters.lr, device=device),
    )
    actor_optimizer = optim.AdamW(
        list(actor.parameters()),
        lr=torch.tensor(cfg.hyperparameters.lr, device=device),
    )
    if envs.asymmetric_obs:
        obs, critic_obs = envs.reset_with_critic_obs()
        critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
    else:
        obs = envs.reset()
        critic_obs = obs
    train_state = TrainState(
        obs=obs,
        critic_obs=critic_obs,
        actor=actor,
        old_actor=old_actor,
        critic=qnet,
        normalizer=obs_normalizer,
        critic_normalizer=critic_obs_normalizer,
        actor_optimizer=actor_optimizer,
        critic_optimizer=q_optimizer,
        device=device,
        scaler=scaler,
    )
    # Print network summaries for the run log.
    print(
        summary(
            train_state.critic,
            input_data=(critic_obs[:1], torch.zeros((1, n_act), device=device)),
            depth=10,
        )
    )
    print(summary(train_state.actor, input_data=(obs[:1],), depth=10))
    # create functions
    collect_fn = make_collect_fn(cfg, envs)
    postprocess_fn = make_postprocess_fn(cfg, envs)
    update_critic = make_critic_update_fn(cfg, train_state)
    update_actor = make_actor_update_fn(cfg, train_state)
    evaluate = make_evaluate_fn(cfg, eval_envs)
    if cfg.platform.compile:
        mode = "max-autotune-no-cudagraphs"
        update_critic = torch.compile(update_critic, mode=mode)
        update_actor = torch.compile(update_actor, mode=mode)
        postprocess_fn = torch.compile(postprocess_fn, mode=mode)
        train_state.compile()
    # TODO: Support checkpoint loading
    # if cfg.checkpoint_path:
    #     # Load checkpoint if specified
    #     torch_checkpoint = torch.load(
    #         f"{cfg.checkpoint_path}", map_location=device, weights_only=False
    #     )
    #     actor.load_state_dict(torch_checkpoint["actor_state_dict"])
    #     obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
    #     critic_obs_normalizer.load_state_dict(
    #         torch_checkpoint["critic_obs_normalizer_state"]
    #     )
    #     qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
    #     qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
    #     global_step = torch_checkpoint["global_step"]
    # else:
    global_step = 0
    # One "global step" is one collect/update iteration consuming
    # num_envs * num_steps environment steps.
    total_env_steps = (
        cfg.hyperparameters.total_time_steps
        // (cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps)
        + 1
    )
    pbar = tqdm.tqdm(total=cfg.hyperparameters.total_time_steps, initial=global_step)
    start_time = None
    desc = ""
    eval_interval = total_env_steps // cfg.hyperparameters.num_eval
    stochastic_eval = cfg.env.get("stochastic_eval", False)
    while global_step < total_env_steps:
        # Start the throughput clock only after a burn-in (e.g. compile warmup).
        if start_time is None and global_step >= cfg.measure_burnin:
            start_time = time.time()
            measure_burnin = global_step
        train_state, transition, infos = collect_fn(train_state)
        data = postprocess_fn(train_state, transition)
        for _ in range(cfg.hyperparameters.num_epochs):
            # Reshuffle the flattened batch each epoch, then sweep minibatches.
            indices = torch.randperm(
                cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps,
                device=device,
            )
            data = data[indices].contiguous()
            for j in range(num_batches):
                mini_batch = data[j * batch_size : (j + 1) * batch_size]
                critic_logs_dict = update_critic(mini_batch)
                actor_logs_dict = update_actor(mini_batch)
                # Only the last minibatch's logs survive for reporting below.
                logs_dict = {
                    **critic_logs_dict,
                    **actor_logs_dict,
                }
        # Hard-sync the KL reference policy to the freshly updated actor.
        for param, target_param in zip(actor.parameters(), old_actor.parameters()):
            target_param.data.copy_(param.data)
        if start_time is not None:
            # @TODO: shouldn't that be env_steps per second?
            speed = (
                cfg.hyperparameters.num_envs
                * cfg.hyperparameters.num_steps
                * (global_step - measure_burnin)
                / (time.time() - start_time)
            )
            pbar.set_description(f"{speed: 4.4f} sps, " + desc)
            with torch.inference_mode():
                logs = {
                    "critic/qf_loss": logs_dict["qf_loss"].mean(),
                    "critic/qf_max": logs_dict["qf_max"].mean(),
                    "critic/qf_min": logs_dict["qf_min"].mean(),
                    "critic/qf_mean": logs_dict["qf_mean"].mean(),
                    "critic/embedding_loss": logs_dict["embedding_loss"].mean(),
                    "critic/critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
                    "actor/actor_loss": logs_dict["actor_loss"].mean(),
                    "actor/actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
                    "actor/kl": logs_dict["kl"].mean(),
                    "actor/entropy": logs_dict["entropy"].mean(),
                    "actor/temperature": logs_dict["temperature"].mean(),
                    "actor/lagrangian": logs_dict["lagrangian"].mean(),
                    "actor/entropy_loss": logs_dict["entropy_loss"].mean(),
                    "actor/lagrangian_loss": logs_dict["lagrangian_loss"].mean(),
                    "train/rewards_batch": data["rewards"].mean(),
                }
                if cfg.env.type == "maniskill":
                    logs.update(
                        {
                            "train/return": torch.stack(
                                [info["log_info"]["return"] for info in infos]
                            ).mean(),
                            "train/episode_len": torch.stack(
                                [info["log_info"]["episode_len"] for info in infos]
                            )
                            .float()
                            .mean(),
                            "train/success": torch.stack(
                                [info["log_info"]["success"] for info in infos]
                            )
                            .float()
                            .mean(),
                        }
                    )
                if eval_interval > 0 and global_step % eval_interval == 0:
                    print(f"Evaluating at global step {global_step}")
                    if stochastic_eval:
                        # Evaluate both sampled and deterministic policies.
                        eval_avg_return, eval_avg_length, stoch_eval_info = evaluate(
                            train_state, stochastic_eval=stochastic_eval
                        )
                        eval_avg_return, eval_avg_length, eval_info = evaluate(
                            train_state
                        )
                        eval_info = {
                            **eval_info,
                            **{f"stoch/{k}": v for k, v in stoch_eval_info.items()},
                        }
                    else:
                        eval_avg_return, eval_avg_length, eval_info = evaluate(
                            train_state
                        )
                    if cfg.env.type in [
                        "humanoid_bench",
                        "isaaclab",
                        "mtbench",
                    ]:
                        # NOTE: Hacky way of evaluating performance, but just works
                        # NOTE(review): this resets the training envs but does not
                        # write the new obs back into train_state — confirm the
                        # stale train_state.obs is acceptable here.
                        obs, _ = envs.reset()
                    logs["eval/avg_return"] = eval_avg_return
                    logs["eval/avg_length"] = eval_avg_length
                    for key, value in eval_info.items():
                        if isinstance(value, torch.Tensor):
                            logs[f"eval/{key}"] = value.mean().item()
                        elif isinstance(value, np.ndarray):
                            logs[f"eval/{key}"] = value.mean()
                        else:
                            logs[f"eval/{key}"] = value
                    print(
                        f"Eval return: {eval_avg_return:.2f}, length: {eval_avg_length:.2f}, env steps: {global_step * cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps} success rate: {eval_info.get('success', 0.0):.2f}"
                    )
                # Log against the cumulative environment-step count.
                wandb.log(
                    {
                        "speed": speed,
                        "frame": global_step
                        * cfg.hyperparameters.num_envs
                        * cfg.hyperparameters.num_steps,
                        **logs,
                    },
                    step=global_step
                    * cfg.hyperparameters.num_envs
                    * cfg.hyperparameters.num_steps,
                )
        global_step += 1
        pbar.update(n=cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps)
if __name__ == "__main__":
    # Hydra entry point; configuration is loaded from ../../config/reppo.yaml.
    main()