# reppo/reppo_alg/torchrl/fast_td3.py
# Snapshot metadata: 2025-07-21 18:31:20 -04:00, 692 lines, 25 KiB, Python

import os
import sys

# Environment configuration must happen before the heavy imports below read it.
# Allow torch.compile/dynamo to trace into built-in nn.Module methods.
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
# Avoid OpenMP thread oversubscription; the workload is GPU-bound.
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
    # Headless EGL rendering for MuJoCo on Linux servers.
    os.environ["MUJOCO_GL"] = "egl"
else:
    # macOS has no EGL; fall back to GLFW windowed rendering.
    os.environ["MUJOCO_GL"] = "glfw"
# Keep JAX (used by mujoco_playground envs) from grabbing all GPU memory,
# and force full-precision matmuls on the JAX side.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"

import math
import random
import time

import numpy as np
import tqdm
import wandb

try:
    # Required for avoiding IsaacGym import error
    import isaacgym
except ImportError:
    pass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from reppo_alg.torchrl.reppo import (
    EmpiricalNormalization,
    PerTaskRewardNormalizer,
    RewardNormalizer,
    SimpleReplayBuffer,
    save_params,
)
from hyperparams import get_args
from tensordict import TensorDict
from torch.amp import GradScaler, autocast

# Allow TF32 matmuls on the torch side for throughput.
torch.set_float32_matmul_precision("high")
def main():
    """Train a FastTD3 agent (distributional twin critics + clipped-noise
    policy smoothing) on one of several simulator back-ends.

    Reads all hyperparameters from ``get_args()``, builds the environment,
    actor/critic networks, optimizers, schedulers and replay buffer, then runs
    the collect/update loop until ``args.total_timesteps``, periodically
    evaluating, logging to wandb, and checkpointing to ``models/``.
    """
    args = get_args()
    print(args)
    run_name = f"{args.env_name}__{args.exp_name}__{args.seed}"

    # --- Mixed-precision setup ------------------------------------------------
    # GradScaler is only needed for fp16; bf16 has enough range without scaling.
    amp_enabled = args.amp and args.cuda and torch.cuda.is_available()
    amp_device_type = (
        "cuda"
        if args.cuda and torch.cuda.is_available()
        else "mps"
        if args.cuda and torch.backends.mps.is_available()
        else "cpu"
    )
    amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16
    scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)

    if args.use_wandb:
        wandb.init(
            project=args.project,
            name=run_name,
            config=vars(args),
            save_code=True,
        )

    # --- Seeding and device ---------------------------------------------------
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    if not args.cuda:
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{args.device_rank}")
        elif torch.backends.mps.is_available():
            device = torch.device(f"mps:{args.device_rank}")
        else:
            raise ValueError("No GPU available")
    print(f"Using device: {device}")

    # --- Environment construction (back-end chosen from the env name prefix) --
    if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
            HumanoidBenchEnv,
        )

        env_type = "humanoid_bench"
        envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
        eval_envs = envs
    elif args.env_name.startswith("Isaac-"):
        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

        env_type = "isaaclab"
        envs = IsaacLabEnv(
            args.env_name,
            device.type,
            args.num_envs,
            args.seed,
            action_bounds=args.action_bounds,
        )
        eval_envs = envs
    elif args.env_name.startswith("MTBench-"):
        from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv

        env_name = "-".join(args.env_name.split("-")[1:])
        env_type = "mtbench"
        envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
        eval_envs = envs
    else:
        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

        # TODO: Check if re-using same envs for eval could reduce memory usage
        env_type = "mujoco_playground"
        envs, eval_envs = make_env(
            args.env_name,
            args.seed,
            args.num_envs,
            args.num_eval_envs,
            args.device_rank,
            use_tuned_reward=args.use_tuned_reward,
            use_domain_randomization=args.use_domain_randomization,
            use_push_randomization=args.use_push_randomization,
        )

    # Observation/action sizes; asymmetric envs expose privileged critic obs.
    n_act = envs.num_actions
    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
    if envs.asymmetric_obs:
        n_critic_obs = (
            envs.num_privileged_obs
            if isinstance(envs.num_privileged_obs, int)
            else envs.num_privileged_obs[0]
        )
    else:
        n_critic_obs = n_obs
    action_low, action_high = -1.0, 1.0

    # --- Normalizers ----------------------------------------------------------
    if args.obs_normalization:
        obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
        critic_obs_normalizer = EmpiricalNormalization(
            shape=n_critic_obs, device=device
        )
    else:
        obs_normalizer = nn.Identity()
        critic_obs_normalizer = nn.Identity()

    if args.reward_normalization:
        if env_type in ["mtbench"]:
            reward_normalizer = PerTaskRewardNormalizer(
                num_tasks=envs.num_tasks,
                gamma=args.gamma,
                device=device,
                # Cap normalized returns inside the critic's value support.
                g_max=min(abs(args.v_min), abs(args.v_max)),
            )
        else:
            reward_normalizer = RewardNormalizer(
                gamma=args.gamma,
                device=device,
                g_max=min(abs(args.v_min), abs(args.v_max)),
            )
    else:
        reward_normalizer = nn.Identity()

    # --- Network construction -------------------------------------------------
    actor_kwargs = {
        "n_obs": n_obs,
        "n_act": n_act,
        "num_envs": args.num_envs,
        "device": device,
        "init_scale": args.init_scale,
        "hidden_dim": args.actor_hidden_dim,
    }
    critic_kwargs = {
        "n_obs": n_critic_obs,
        "n_act": n_act,
        "num_atoms": args.num_atoms,
        "v_min": args.v_min,
        "v_max": args.v_max,
        "hidden_dim": args.critic_hidden_dim,
        "device": device,
    }

    if env_type == "mtbench":
        # MTBench appends a one-hot task id to the obs; the nets replace it
        # with a learned task embedding, so adjust the input widths.
        actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
        critic_kwargs["n_obs"] = n_critic_obs - envs.num_tasks + args.task_embedding_dim
        actor_kwargs["num_tasks"] = envs.num_tasks
        actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
        critic_kwargs["num_tasks"] = envs.num_tasks
        critic_kwargs["task_embedding_dim"] = args.task_embedding_dim

    if args.agent == "fasttd3":
        if env_type in ["mtbench"]:
            from reppo_alg.network_utils.fast_td3_nets import (
                MultiTaskActor,
                MultiTaskCritic,
            )

            actor_cls = MultiTaskActor
            critic_cls = MultiTaskCritic
        else:
            from reppo_alg.network_utils.fast_td3_nets import Actor, Critic

            actor_cls = Actor
            critic_cls = Critic

        print("Using FastTD3")
    elif args.agent == "fasttd3_simbav2":
        if env_type in ["mtbench"]:
            from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
                MultiTaskActor,
                MultiTaskCritic,
            )

            actor_cls = MultiTaskActor
            critic_cls = MultiTaskCritic
        else:
            from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic

            actor_cls = Actor
            critic_cls = Critic

        print("Using FastTD3 + SimbaV2")
        # SimbaV2 nets use hyperspherical scalers instead of init_scale.
        actor_kwargs.pop("init_scale")
        actor_kwargs.update(
            {
                "scaler_init": math.sqrt(2.0 / args.actor_hidden_dim),
                "scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim),
                "alpha_init": 1.0 / (args.actor_num_blocks + 1),
                "alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim),
                "expansion": 4,
                "c_shift": 3.0,
                "num_blocks": args.actor_num_blocks,
            }
        )
        critic_kwargs.update(
            {
                "scaler_init": math.sqrt(2.0 / args.critic_hidden_dim),
                "scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim),
                "alpha_init": 1.0 / (args.critic_num_blocks + 1),
                "alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim),
                "num_blocks": args.critic_num_blocks,
                "expansion": 4,
                "c_shift": 3.0,
            }
        )
    else:
        raise ValueError(f"Agent {args.agent} not supported")

    actor = actor_cls(**actor_kwargs)

    if env_type in ["mtbench"]:
        # Python 3.8 doesn't support 'from_module' in tensordict
        policy = actor.explore
    else:
        from tensordict import from_module

        # Detached copy of the actor used for rollouts so exploration never
        # interacts with autograd on the training actor.
        actor_detach = actor_cls(**actor_kwargs)
        # Copy params to actor_detach without grad
        from_module(actor).data.to_module(actor_detach)
        policy = actor_detach.explore

    qnet = critic_cls(**critic_kwargs)
    qnet_target = critic_cls(**critic_kwargs)
    qnet_target.load_state_dict(qnet.state_dict())

    # NOTE: learning rates are passed as tensors so torch.compile does not
    # recompile when the scheduler changes the value.
    q_optimizer = optim.AdamW(
        list(qnet.parameters()),
        lr=torch.tensor(args.critic_learning_rate, device=device),
        weight_decay=args.weight_decay,
    )
    actor_optimizer = optim.AdamW(
        list(actor.parameters()),
        lr=torch.tensor(args.actor_learning_rate, device=device),
        weight_decay=args.weight_decay,
    )

    # Add learning rate schedulers
    q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        q_optimizer,
        T_max=args.total_timesteps,
        eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
    )
    actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        actor_optimizer,
        T_max=args.total_timesteps,
        eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
    )

    rb = SimpleReplayBuffer(
        n_env=args.num_envs,
        buffer_size=args.buffer_size,
        n_obs=n_obs,
        n_act=n_act,
        n_critic_obs=n_critic_obs,
        asymmetric_obs=envs.asymmetric_obs,
        playground_mode=env_type == "mujoco_playground",
        n_steps=args.num_steps,
        gamma=args.gamma,
        device=device,
    )

    policy_noise = args.policy_noise
    noise_clip = args.noise_clip

    def evaluate():
        """Roll out the deterministic actor in the eval envs and return
        (mean episode return, mean episode length) over one episode per env."""
        obs_normalizer.eval()
        num_eval_envs = eval_envs.num_envs
        episode_returns = torch.zeros(num_eval_envs, device=device)
        episode_lengths = torch.zeros(num_eval_envs, device=device)
        done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)

        if env_type == "isaaclab":
            obs = eval_envs.reset(random_start_init=False)
        else:
            obs = eval_envs.reset()

        # Run for a fixed number of steps
        for i in range(eval_envs.max_episode_steps):
            with (
                torch.no_grad(),
                autocast(
                    device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
                ),
            ):
                obs = normalize_obs(obs)
                actions = actor(obs)

            next_obs, rewards, dones, _, infos = eval_envs.step(actions.float())

            if env_type == "mtbench":
                # We only report success rate in MTBench evaluation
                rewards = (
                    infos["episode"]["success"].float() if "episode" in infos else 0.0
                )

            # Only accumulate for envs that have not finished yet.
            episode_returns = torch.where(
                ~done_masks, episode_returns + rewards, episode_returns
            )
            episode_lengths = torch.where(
                ~done_masks, episode_lengths + 1, episode_lengths
            )
            if env_type == "mtbench" and "episode" in infos:
                # Treat task success as episode termination for evaluation.
                dones = dones | infos["episode"]["success"]
            done_masks = torch.logical_or(done_masks, dones)
            if done_masks.all():
                break
            obs = next_obs

        obs_normalizer.train()
        return episode_returns.mean().item(), episode_lengths.mean().item()

    def update_main(data, logs_dict):
        """One distributional twin-critic update (cross-entropy against the
        projected target distribution). Mutates and returns ``logs_dict``."""
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            observations = data["observations"]
            next_observations = data["next"]["observations"]
            if envs.asymmetric_obs:
                critic_observations = data["critic_observations"]
                next_critic_observations = data["next"]["critic_observations"]
            else:
                critic_observations = observations
                next_critic_observations = next_observations
            actions = data["actions"]
            rewards = data["next"]["rewards"]
            dones = data["next"]["dones"].bool()
            truncations = data["next"]["truncations"].bool()
            # Bootstrap through truncations (time-outs) but not true terminals,
            # unless bootstrapping is disabled entirely.
            if args.disable_bootstrap:
                bootstrap = (~dones).float()
            else:
                bootstrap = (truncations | ~dones).float()

            # Target policy smoothing: clipped Gaussian noise on next actions.
            clipped_noise = torch.randn_like(actions)
            clipped_noise = clipped_noise.mul(policy_noise).clamp(
                -noise_clip, noise_clip
            )

            next_state_actions = (actor(next_observations) + clipped_noise).clamp(
                action_low, action_high
            )
            # n-step returns: discount varies per sample with the effective n.
            discount = args.gamma ** data["next"]["effective_n_steps"]

            with torch.no_grad():
                qf1_next_target_projected, qf2_next_target_projected = (
                    qnet_target.projection(
                        next_critic_observations,
                        next_state_actions,
                        rewards,
                        bootstrap,
                        discount,
                    )
                )
                qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
                qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)
                if args.use_cdq:
                    # Clipped double Q: per-sample pick the distribution with
                    # the smaller expected value as the shared target.
                    qf_next_target_dist = torch.where(
                        qf1_next_target_value.unsqueeze(1)
                        < qf2_next_target_value.unsqueeze(1),
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )
                    qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist
                else:
                    qf1_next_target_dist, qf2_next_target_dist = (
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )

            qf1, qf2 = qnet(critic_observations, actions)
            # Cross-entropy between target distribution and predicted logits.
            qf1_loss = -torch.sum(
                qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1
            ).mean()
            qf2_loss = -torch.sum(
                qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1
            ).mean()
            qf_loss = qf1_loss + qf2_loss

        # Backward/step outside autocast; unscale before clipping so the norm
        # is measured in true gradient units.
        q_optimizer.zero_grad(set_to_none=True)
        scaler.scale(qf_loss).backward()
        scaler.unscale_(q_optimizer)
        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            qnet.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(q_optimizer)
        scaler.update()
        q_scheduler.step()

        logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
        logs_dict["qf_loss"] = qf_loss.detach()
        logs_dict["qf_max"] = qf1_next_target_value.max().detach()
        logs_dict["qf_min"] = qf1_next_target_value.min().detach()
        return logs_dict

    def update_pol(data, logs_dict):
        """One deterministic-policy-gradient actor update (maximize expected
        critic value). Mutates and returns ``logs_dict``."""
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            critic_observations = (
                data["critic_observations"]
                if envs.asymmetric_obs
                else data["observations"]
            )

            qf1, qf2 = qnet(critic_observations, actor(data["observations"]))
            qf1_value = qnet.get_value(F.softmax(qf1, dim=1))
            qf2_value = qnet.get_value(F.softmax(qf2, dim=1))
            if args.use_cdq:
                qf_value = torch.minimum(qf1_value, qf2_value)
            else:
                qf_value = (qf1_value + qf2_value) / 2.0
            actor_loss = -qf_value.mean()

        actor_optimizer.zero_grad(set_to_none=True)
        scaler.scale(actor_loss).backward()
        scaler.unscale_(actor_optimizer)
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
            actor.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(actor_optimizer)
        scaler.update()
        actor_scheduler.step()

        logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
        logs_dict["actor_loss"] = actor_loss.detach()
        return logs_dict

    # --- Optional torch.compile of the hot paths ------------------------------
    if args.compile:
        mode = None
        update_main = torch.compile(update_main, mode=mode)
        update_pol = torch.compile(update_pol, mode=mode)
        policy = torch.compile(policy, mode=mode)
        normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
        normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
        if args.reward_normalization:
            update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
            normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
    else:
        normalize_obs = obs_normalizer.forward
        normalize_critic_obs = critic_obs_normalizer.forward
        if args.reward_normalization:
            update_stats = reward_normalizer.update_stats
            normalize_reward = reward_normalizer.forward

    if envs.asymmetric_obs:
        obs, critic_obs = envs.reset_with_critic_obs()
        critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
    else:
        obs = envs.reset()

    if args.checkpoint_path:
        # Load checkpoint if specified
        torch_checkpoint = torch.load(
            f"{args.checkpoint_path}", map_location=device, weights_only=False
        )
        actor.load_state_dict(torch_checkpoint["actor_state_dict"])
        obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
        critic_obs_normalizer.load_state_dict(
            torch_checkpoint["critic_obs_normalizer_state"]
        )
        qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
        qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
        global_step = torch_checkpoint["global_step"]
    else:
        global_step = 0

    # --- Main collect/update loop ---------------------------------------------
    dones = None
    pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)
    start_time = None
    desc = ""

    while global_step < args.total_timesteps:
        logs_dict = TensorDict()
        # Start the speed clock only after warm-up so SPS excludes compile/fill.
        if (
            start_time is None
            and global_step >= args.measure_burnin + args.learning_starts
        ):
            start_time = time.time()
            measure_burnin = global_step

        with (
            torch.no_grad(),
            autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled),
        ):
            norm_obs = normalize_obs(obs)
            actions = policy(obs=norm_obs, dones=dones)

        next_obs, rewards, dones, _, infos = envs.step(actions.float())
        truncations = infos["time_outs"]

        if args.reward_normalization:
            if env_type == "mtbench":
                # Task id is the trailing one-hot slice of the observation.
                task_ids_one_hot = obs[..., -envs.num_tasks :]
                task_indices = torch.argmax(task_ids_one_hot, dim=1)
                update_stats(rewards, dones.float(), task_ids=task_indices)
            else:
                update_stats(rewards, dones.float())

        if envs.asymmetric_obs:
            next_critic_obs = infos["observations"]["critic"]

        # Compute 'true' next_obs and next_critic_obs for saving:
        # on terminal steps the env auto-resets, so the pre-reset observation
        # lives in infos["observations"]["raw"].
        true_next_obs = torch.where(
            dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
        )
        if envs.asymmetric_obs:
            true_next_critic_obs = torch.where(
                dones[:, None] > 0,
                infos["observations"]["raw"]["critic_obs"],
                next_critic_obs,
            )

        transition = TensorDict(
            {
                "observations": obs,
                "actions": torch.as_tensor(actions, device=device, dtype=torch.float),
                "next": {
                    "observations": true_next_obs,
                    "rewards": torch.as_tensor(
                        rewards, device=device, dtype=torch.float
                    ),
                    "truncations": truncations.long(),
                    "dones": dones.long(),
                },
            },
            batch_size=(envs.num_envs,),
            device=device,
        )
        if envs.asymmetric_obs:
            transition["critic_observations"] = critic_obs
            transition["next"]["critic_observations"] = true_next_critic_obs

        obs = next_obs
        if envs.asymmetric_obs:
            critic_obs = next_critic_obs

        rb.extend(transition)

        batch_size = args.batch_size // args.num_envs
        if global_step > args.learning_starts:
            for i in range(args.num_updates):
                data = rb.sample(batch_size)
                data["observations"] = normalize_obs(data["observations"])
                data["next"]["observations"] = normalize_obs(
                    data["next"]["observations"]
                )
                raw_rewards = data["next"]["rewards"]
                if env_type in ["mtbench"] and args.reward_normalization:
                    # Multi-task reward normalization
                    task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
                    task_indices = torch.argmax(task_ids_one_hot, dim=1)
                    data["next"]["rewards"] = normalize_reward(
                        raw_rewards, task_ids=task_indices
                    )
                else:
                    data["next"]["rewards"] = normalize_reward(raw_rewards)
                if envs.asymmetric_obs:
                    data["critic_observations"] = normalize_critic_obs(
                        data["critic_observations"]
                    )
                    data["next"]["critic_observations"] = normalize_critic_obs(
                        data["next"]["critic_observations"]
                    )
                logs_dict = update_main(data, logs_dict)
                # Delayed policy updates (TD3): actor steps less often than
                # the critic, keyed on the inner index when multiple updates
                # run per env step, otherwise on the global step.
                if args.num_updates > 1:
                    if i % args.policy_frequency == 1:
                        logs_dict = update_pol(data, logs_dict)
                else:
                    if global_step % args.policy_frequency == 0:
                        logs_dict = update_pol(data, logs_dict)

                # Polyak-average the target critic toward the online critic.
                for param, target_param in zip(
                    qnet.parameters(), qnet_target.parameters()
                ):
                    target_param.data.copy_(
                        args.tau * param.data + (1 - args.tau) * target_param.data
                    )

            if global_step % 100 == 0 and start_time is not None:
                speed = (global_step - measure_burnin) / (time.time() - start_time)
                pbar.set_description(f"{speed: 4.4f} sps, " + desc)
                with torch.no_grad():
                    logs = {
                        "qf_loss": logs_dict["qf_loss"].mean(),
                        "qf_max": logs_dict["qf_max"].mean(),
                        "qf_min": logs_dict["qf_min"].mean(),
                        "critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
                        "env_rewards": rewards.mean(),
                        "buffer_rewards": raw_rewards.mean(),
                    }
                    # The actor only updates every `policy_frequency` steps,
                    # so its stats may be absent from this step's logs_dict;
                    # reading them unconditionally would raise a KeyError.
                    if "actor_loss" in logs_dict.keys():
                        logs["actor_loss"] = logs_dict["actor_loss"].mean()
                        logs["actor_grad_norm"] = logs_dict["actor_grad_norm"].mean()

                    if args.eval_interval > 0 and global_step % args.eval_interval == 0:
                        print(f"Evaluating at global step {global_step}")
                        eval_avg_return, eval_avg_length = evaluate()
                        if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
                            # NOTE: Hacky way of evaluating performance, but just works
                            obs = envs.reset()
                        logs["eval_avg_return"] = eval_avg_return
                        logs["eval_avg_length"] = eval_avg_length

                if args.use_wandb:
                    wandb.log(
                        {
                            "speed": speed,
                            "frame": global_step * args.num_envs,
                            "critic_lr": q_scheduler.get_last_lr()[0],
                            "actor_lr": actor_scheduler.get_last_lr()[0],
                            **logs,
                        },
                        step=global_step,
                    )

            if (
                args.save_interval > 0
                and global_step > 0
                and global_step % args.save_interval == 0
            ):
                print(f"Saving model at global step {global_step}")
                save_params(
                    global_step,
                    actor,
                    qnet,
                    qnet_target,
                    obs_normalizer,
                    critic_obs_normalizer,
                    args,
                    f"models/{run_name}_{global_step}.pt",
                )

        global_step += 1
        pbar.update(1)

    # Final checkpoint after training completes.
    save_params(
        global_step,
        actor,
        qnet,
        qnet_target,
        obs_normalizer,
        critic_obs_normalizer,
        args,
        f"models/{run_name}_final.pt",
    )
# Standard script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()