import os import sys os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1" os.environ["OMP_NUM_THREADS"] = "1" if sys.platform != "darwin": os.environ["MUJOCO_GL"] = "egl" else: os.environ["MUJOCO_GL"] = "glfw" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest" import random import time import tqdm import wandb import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.amp import autocast, GradScaler from tensordict import TensorDict, from_module from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params from hyperparams import get_args from fast_td3 import Actor, Critic torch.set_float32_matmul_precision("high") try: import jax.numpy as jnp except ImportError: pass def main(): args = get_args() print(args) run_name = f"{args.env_name}__{args.exp_name}__{args.seed}" amp_enabled = args.amp and args.cuda and torch.cuda.is_available() amp_device_type = ( "cuda" if args.cuda and torch.cuda.is_available() else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu" ) amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16 scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16) if args.use_wandb: wandb.init( project=args.project, name=run_name, config=vars(args), save_code=True, ) random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic if not args.cuda: device = torch.device("cpu") else: if torch.cuda.is_available(): device = torch.device(f"cuda:{args.device_rank}") elif torch.backends.mps.is_available(): device = torch.device(f"mps:{args.device_rank}") else: raise ValueError("No GPU available") print(f"Using device: {device}") if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"): from environments.humanoid_bench_env import HumanoidBenchEnv env_type = "humanoid_bench" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device) eval_envs = envs render_env = HumanoidBenchEnv( args.env_name, 1, render_mode="rgb_array", device=device ) elif args.env_name.startswith("Isaac-"): from environments.isaaclab_env import IsaacLabEnv env_type = "isaaclab" envs = IsaacLabEnv( args.env_name, device.type, args.num_envs, args.seed, action_bounds=args.action_bounds, ) eval_envs = envs render_env = envs else: from environments.mujoco_playground_env import make_env # TODO: Check if re-using same envs for eval could reduce memory usage env_type = "mujoco_playground" envs, eval_envs, render_env = make_env( args.env_name, args.seed, args.num_envs, args.num_eval_envs, args.device_rank, use_tuned_reward=args.use_tuned_reward, use_domain_randomization=args.use_domain_randomization, use_push_randomization=args.use_push_randomization, ) n_act = envs.num_actions n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0] if envs.asymmetric_obs: n_critic_obs = ( envs.num_privileged_obs if type(envs.num_privileged_obs) == int else envs.num_privileged_obs[0] ) else: n_critic_obs = n_obs action_low, action_high = -1.0, 1.0 if args.obs_normalization: obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device) critic_obs_normalizer = EmpiricalNormalization( shape=n_critic_obs, device=device ) else: obs_normalizer = nn.Identity() critic_obs_normalizer = nn.Identity() actor = Actor( n_obs=n_obs, n_act=n_act, num_envs=args.num_envs, device=device, init_scale=args.init_scale, hidden_dim=args.actor_hidden_dim, ) actor_detach = Actor( n_obs=n_obs, n_act=n_act, num_envs=args.num_envs, device=device, init_scale=args.init_scale, hidden_dim=args.actor_hidden_dim, ) # Copy params to actor_detach without grad from_module(actor).data.to_module(actor_detach) policy = actor_detach.explore qnet = Critic( n_obs=n_critic_obs, n_act=n_act, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, hidden_dim=args.critic_hidden_dim, device=device, ) qnet_target = Critic( n_obs=n_critic_obs, n_act=n_act, num_atoms=args.num_atoms, v_min=args.v_min, v_max=args.v_max, hidden_dim=args.critic_hidden_dim, device=device, ) qnet_target.load_state_dict(qnet.state_dict()) q_optimizer = optim.AdamW( list(qnet.parameters()), lr=args.critic_learning_rate, weight_decay=args.weight_decay, ) actor_optimizer = optim.AdamW( list(actor.parameters()), lr=args.actor_learning_rate, weight_decay=args.weight_decay, ) rb = SimpleReplayBuffer( n_env=args.num_envs, buffer_size=args.buffer_size, n_obs=n_obs, n_act=n_act, n_critic_obs=n_critic_obs, asymmetric_obs=envs.asymmetric_obs, playground_mode=env_type == "mujoco_playground", n_steps=args.num_steps, gamma=args.gamma, device=device, ) policy_noise = args.policy_noise noise_clip = args.noise_clip def evaluate(): obs_normalizer.eval() num_eval_envs = eval_envs.num_envs episode_returns = torch.zeros(num_eval_envs, device=device) episode_lengths = torch.zeros(num_eval_envs, device=device) done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device) if env_type == "isaaclab": obs = eval_envs.reset(random_start_init=False) else: obs = eval_envs.reset() # Run for a fixed number of steps for _ in range(eval_envs.max_episode_steps): with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): obs = normalize_obs(obs) actions = actor(obs) next_obs, rewards, dones, _ = eval_envs.step(actions.float()) episode_returns = torch.where( ~done_masks, episode_returns + rewards, episode_returns ) episode_lengths = torch.where( ~done_masks, episode_lengths + 1, episode_lengths ) done_masks = torch.logical_or(done_masks, dones) if done_masks.all(): break obs = next_obs obs_normalizer.train() return episode_returns.mean().item(), episode_lengths.mean().item() def render_with_rollout(): obs_normalizer.eval() # Quick rollout for rendering if env_type == "humanoid_bench": obs = render_env.reset() renders = [render_env.render()] elif env_type == "isaaclab": raise NotImplementedError( "We don't support rendering for IsaacLab environments" ) else: obs = render_env.reset() render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) renders = [render_env.state] for i in range(render_env.max_episode_steps): with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): obs = normalize_obs(obs) actions = actor(obs) next_obs, _, done, _ = render_env.step(actions.float()) if env_type == "mujoco_playground": render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]]) if i % 2 == 0: if env_type == "humanoid_bench": renders.append(render_env.render()) else: renders.append(render_env.state) if done.any(): break obs = next_obs if env_type == "mujoco_playground": renders = render_env.render_trajectory(renders) obs_normalizer.train() return renders def update_main(data, logs_dict): with autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): observations = data["observations"] next_observations = data["next"]["observations"] if envs.asymmetric_obs: critic_observations = data["critic_observations"] next_critic_observations = data["next"]["critic_observations"] else: critic_observations = observations next_critic_observations = next_observations actions = data["actions"] rewards = data["next"]["rewards"] dones = data["next"]["dones"].bool() truncations = data["next"]["truncations"].bool() if args.disable_bootstrap: bootstrap = (~dones).float() else: bootstrap = (truncations | ~dones).float() clipped_noise = torch.randn_like(actions) clipped_noise = clipped_noise.mul(policy_noise).clamp( -noise_clip, noise_clip ) next_state_actions = (actor(next_observations) + clipped_noise).clamp( action_low, action_high ) with torch.no_grad(): qf1_next_target_projected, qf2_next_target_projected = ( qnet_target.projection( next_critic_observations, next_state_actions, rewards, bootstrap, args.gamma, ) ) qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected) qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected) if args.use_cdq: qf_next_target_dist = torch.where( qf1_next_target_value.unsqueeze(1) < qf2_next_target_value.unsqueeze(1), qf1_next_target_projected, qf2_next_target_projected, ) qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist else: qf1_next_target_dist, qf2_next_target_dist = ( qf1_next_target_projected, qf2_next_target_projected, ) qf1, qf2 = qnet(critic_observations, actions) qf1_loss = -torch.sum( qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1 ).mean() qf2_loss = -torch.sum( qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1 ).mean() qf_loss = qf1_loss + qf2_loss q_optimizer.zero_grad(set_to_none=True) scaler.scale(qf_loss).backward() scaler.unscale_(q_optimizer) critic_grad_norm = torch.nn.utils.clip_grad_norm_( qnet.parameters(), max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), ) scaler.step(q_optimizer) scaler.update() logs_dict["buffer_rewards"] = rewards.mean() logs_dict["critic_grad_norm"] = critic_grad_norm.detach() logs_dict["qf_loss"] = qf_loss.detach() logs_dict["qf_max"] = qf1_next_target_value.max().detach() logs_dict["qf_min"] = qf1_next_target_value.min().detach() return logs_dict def update_pol(data, logs_dict): with autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): critic_observations = ( data["critic_observations"] if envs.asymmetric_obs else data["observations"] ) qf1, qf2 = qnet(critic_observations, actor(data["observations"])) qf1_value = qnet.get_value(F.softmax(qf1, dim=1)) qf2_value = qnet.get_value(F.softmax(qf2, dim=1)) if args.use_cdq: qf_value = torch.minimum(qf1_value, qf2_value) else: qf_value = (qf1_value + qf2_value) / 2.0 actor_loss = -qf_value.mean() actor_optimizer.zero_grad(set_to_none=True) scaler.scale(actor_loss).backward() scaler.unscale_(actor_optimizer) actor_grad_norm = torch.nn.utils.clip_grad_norm_( actor.parameters(), max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"), ) scaler.step(actor_optimizer) scaler.update() logs_dict["actor_grad_norm"] = actor_grad_norm.detach() logs_dict["actor_loss"] = actor_loss.detach() return logs_dict if args.compile: mode = None update_main = torch.compile(update_main, mode=mode) update_pol = torch.compile(update_pol, mode=mode) policy = torch.compile(policy, mode=mode) normalize_obs = torch.compile(obs_normalizer.forward, mode=mode) normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode) else: normalize_obs = obs_normalizer.forward normalize_critic_obs = critic_obs_normalizer.forward if envs.asymmetric_obs: obs, critic_obs = envs.reset_with_critic_obs() critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float) else: obs = envs.reset() if args.checkpoint_path: # Load checkpoint if specified torch_checkpoint = torch.load( f"{args.checkpoint_path}", map_location=device, weights_only=False ) actor.load_state_dict(torch_checkpoint["actor_state_dict"]) obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"]) critic_obs_normalizer.load_state_dict( torch_checkpoint["critic_obs_normalizer_state"] ) qnet.load_state_dict(torch_checkpoint["qnet_state_dict"]) qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"]) global_step = torch_checkpoint["global_step"] else: global_step = 0 dones = None pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step) start_time = None desc = "" while global_step < args.total_timesteps: logs_dict = TensorDict() if ( start_time is None and global_step >= args.measure_burnin + args.learning_starts ): start_time = time.time() measure_burnin = global_step with torch.no_grad(), autocast( device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled ): norm_obs = normalize_obs(obs) actions = policy(obs=norm_obs, dones=dones) next_obs, rewards, dones, infos = envs.step(actions.float()) truncations = infos["time_outs"] if envs.asymmetric_obs: next_critic_obs = infos["observations"]["critic"] # Compute 'true' next_obs and next_critic_obs for saving true_next_obs = torch.where( dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs ) if envs.asymmetric_obs: true_next_critic_obs = torch.where( dones[:, None] > 0, infos["observations"]["raw"]["critic_obs"], next_critic_obs, ) transition = TensorDict( { "observations": obs, "actions": torch.as_tensor(actions, device=device, dtype=torch.float), "next": { "observations": true_next_obs, "rewards": torch.as_tensor( rewards, device=device, dtype=torch.float ), "truncations": truncations.long(), "dones": dones.long(), }, }, batch_size=(envs.num_envs,), device=device, ) if envs.asymmetric_obs: transition["critic_observations"] = critic_obs transition["next"]["critic_observations"] = true_next_critic_obs obs = next_obs if envs.asymmetric_obs: critic_obs = next_critic_obs rb.extend(transition) batch_size = args.batch_size // args.num_envs if global_step > args.learning_starts: for i in range(args.num_updates): data = rb.sample(batch_size) data["observations"] = normalize_obs(data["observations"]) data["next"]["observations"] = normalize_obs( data["next"]["observations"] ) if envs.asymmetric_obs: data["critic_observations"] = normalize_critic_obs( data["critic_observations"] ) data["next"]["critic_observations"] = normalize_critic_obs( data["next"]["critic_observations"] ) logs_dict = update_main(data, logs_dict) if args.num_updates > 1: if i % args.policy_frequency == 1: logs_dict = update_pol(data, logs_dict) else: if global_step % args.policy_frequency == 0: logs_dict = update_pol(data, logs_dict) for param, target_param in zip( qnet.parameters(), qnet_target.parameters() ): target_param.data.copy_( args.tau * param.data + (1 - args.tau) * target_param.data ) if global_step % 100 == 0 and start_time is not None: speed = (global_step - measure_burnin) / (time.time() - start_time) pbar.set_description(f"{speed: 4.4f} sps, " + desc) with torch.no_grad(): logs = { "actor_loss": logs_dict["actor_loss"].mean(), "qf_loss": logs_dict["qf_loss"].mean(), "qf_max": logs_dict["qf_max"].mean(), "qf_min": logs_dict["qf_min"].mean(), "actor_grad_norm": logs_dict["actor_grad_norm"].mean(), "critic_grad_norm": logs_dict["critic_grad_norm"].mean(), "buffer_rewards": logs_dict["buffer_rewards"].mean(), "env_rewards": rewards.mean(), } if args.eval_interval > 0 and global_step % args.eval_interval == 0: print(f"Evaluating at global step {global_step}") eval_avg_return, eval_avg_length = evaluate() if env_type in ["humanoid_bench", "isaaclab"]: # NOTE: Hacky way of evaluating performance, but just works obs = envs.reset() logs["eval_avg_return"] = eval_avg_return logs["eval_avg_length"] = eval_avg_length if ( args.render_interval > 0 and global_step % args.render_interval == 0 ): renders = render_with_rollout() if args.use_wandb: wandb.log( { "render_video": wandb.Video( np.array(renders).transpose( 0, 3, 1, 2 ), # Convert to (T, C, H, W) format fps=30, format="gif", ) }, step=global_step, ) if args.use_wandb: wandb.log( { "speed": speed, "frame": global_step * args.num_envs, **logs, }, step=global_step, ) if ( args.save_interval > 0 and global_step > 0 and global_step % args.save_interval == 0 ): print(f"Saving model at global step {global_step}") save_params( global_step, actor, qnet, qnet_target, obs_normalizer, critic_obs_normalizer, args, f"models/{run_name}_{global_step}.pt", ) global_step += 1 pbar.update(1) save_params( global_step, actor, qnet, qnet_target, obs_normalizer, critic_obs_normalizer, args, f"models/{run_name}_final.pt", ) if __name__ == "__main__": main()