import os
import sys

os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
    os.environ["MUJOCO_GL"] = "egl"
else:
    os.environ["MUJOCO_GL"] = "glfw"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"

import math
import random
import time

import numpy as np
import tqdm
import wandb

try:
    # Required for avoiding IsaacGym import error
    import isaacgym
except ImportError:
    pass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from reppo_alg.torchrl.reppo import (
    EmpiricalNormalization,
    PerTaskRewardNormalizer,
    RewardNormalizer,
    SimpleReplayBuffer,
    save_params,
)
from hyperparams import get_args
from tensordict import TensorDict
from torch.amp import GradScaler, autocast

torch.set_float32_matmul_precision("high")


def main():
    args = get_args()
    print(args)
    run_name = f"{args.env_name}__{args.exp_name}__{args.seed}"

    amp_enabled = args.amp and args.cuda and torch.cuda.is_available()
    amp_device_type = (
        "cuda"
        if args.cuda and torch.cuda.is_available()
        else "mps"
        if args.cuda and torch.backends.mps.is_available()
        else "cpu"
    )
    amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16
    scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)

    if args.use_wandb:
        wandb.init(
            project=args.project,
            name=run_name,
            config=vars(args),
            save_code=True,
        )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    if not args.cuda:
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{args.device_rank}")
        elif torch.backends.mps.is_available():
            device = torch.device(f"mps:{args.device_rank}")
        else:
            raise ValueError("No GPU available")
    print(f"Using device: {device}")

    if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
        from reppo_alg.env_utils.torch_wrappers.humanoid_bench_env import (
            HumanoidBenchEnv,
        )

        env_type = "humanoid_bench"
        envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
        eval_envs = envs
    elif args.env_name.startswith("Isaac-"):
        from reppo_alg.env_utils.torch_wrappers.isaaclab_env import IsaacLabEnv

        env_type = "isaaclab"
        envs = IsaacLabEnv(
            args.env_name,
            device.type,
            args.num_envs,
            args.seed,
            action_bounds=args.action_bounds,
        )
        eval_envs = envs
    elif args.env_name.startswith("MTBench-"):
        from reppo_alg.env_utils.torch_wrappers.mtbench_env import MTBenchEnv

        env_name = "-".join(args.env_name.split("-")[1:])
        env_type = "mtbench"
        envs = MTBenchEnv(env_name, args.device_rank, args.num_envs, args.seed)
        eval_envs = envs
    else:
        from reppo_alg.env_utils.torch_wrappers.mujoco_playground_env import make_env

        # TODO: Check if re-using same envs for eval could reduce memory usage
        env_type = "mujoco_playground"
        envs, eval_envs = make_env(
            args.env_name,
            args.seed,
            args.num_envs,
            args.num_eval_envs,
            args.device_rank,
            use_tuned_reward=args.use_tuned_reward,
            use_domain_randomization=args.use_domain_randomization,
            use_push_randomization=args.use_push_randomization,
        )

    n_act = envs.num_actions
    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
    if envs.asymmetric_obs:
        n_critic_obs = (
            envs.num_privileged_obs
            if isinstance(envs.num_privileged_obs, int)
            else envs.num_privileged_obs[0]
        )
    else:
        n_critic_obs = n_obs
    action_low, action_high = -1.0, 1.0
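    # Observation normalization keeps actor/critic inputs well-scaled.
    # EmpiricalNormalization (imported above) tracks running mean/variance and,
    # conceptually, rescales inputs as (x - mean) / sqrt(var + eps); with
    # normalization disabled, nn.Identity() keeps the code path uniform.
    # (Conceptual description of the helper, not a restatement of its implementation.)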
    if args.obs_normalization:
        obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
        critic_obs_normalizer = EmpiricalNormalization(
            shape=n_critic_obs, device=device
        )
    else:
        obs_normalizer = nn.Identity()
        critic_obs_normalizer = nn.Identity()

    if args.reward_normalization:
        if env_type in ["mtbench"]:
            reward_normalizer = PerTaskRewardNormalizer(
                num_tasks=envs.num_tasks,
                gamma=args.gamma,
                device=device,
                g_max=min(abs(args.v_min), abs(args.v_max)),
            )
        else:
            reward_normalizer = RewardNormalizer(
                gamma=args.gamma,
                device=device,
                g_max=min(abs(args.v_min), abs(args.v_max)),
            )
    else:
        reward_normalizer = nn.Identity()

    actor_kwargs = {
        "n_obs": n_obs,
        "n_act": n_act,
        "num_envs": args.num_envs,
        "device": device,
        "init_scale": args.init_scale,
        "hidden_dim": args.actor_hidden_dim,
    }
    critic_kwargs = {
        "n_obs": n_critic_obs,
        "n_act": n_act,
        "num_atoms": args.num_atoms,
        "v_min": args.v_min,
        "v_max": args.v_max,
        "hidden_dim": args.critic_hidden_dim,
        "device": device,
    }

    if env_type == "mtbench":
        actor_kwargs["n_obs"] = n_obs - envs.num_tasks + args.task_embedding_dim
        critic_kwargs["n_obs"] = (
            n_critic_obs - envs.num_tasks + args.task_embedding_dim
        )
        actor_kwargs["num_tasks"] = envs.num_tasks
        actor_kwargs["task_embedding_dim"] = args.task_embedding_dim
        critic_kwargs["num_tasks"] = envs.num_tasks
        critic_kwargs["task_embedding_dim"] = args.task_embedding_dim

    if args.agent == "fasttd3":
        if env_type in ["mtbench"]:
            from reppo_alg.network_utils.fast_td3_nets import (
                MultiTaskActor,
                MultiTaskCritic,
            )

            actor_cls = MultiTaskActor
            critic_cls = MultiTaskCritic
        else:
            from reppo_alg.network_utils.fast_td3_nets import Actor, Critic

            actor_cls = Actor
            critic_cls = Critic

        print("Using FastTD3")
    elif args.agent == "fasttd3_simbav2":
        if env_type in ["mtbench"]:
            from reppo_alg.network_utils.fast_td3_nets_simbav2 import (
                MultiTaskActor,
                MultiTaskCritic,
            )

            actor_cls = MultiTaskActor
            critic_cls = MultiTaskCritic
        else:
            from reppo_alg.network_utils.fast_td3_nets_simbav2 import Actor, Critic

            actor_cls = Actor
            critic_cls = Critic

        print("Using FastTD3 + SimbaV2")
        actor_kwargs.pop("init_scale")
        actor_kwargs.update(
            {
                "scaler_init": math.sqrt(2.0 / args.actor_hidden_dim),
                "scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim),
                "alpha_init": 1.0 / (args.actor_num_blocks + 1),
                "alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim),
                "expansion": 4,
                "c_shift": 3.0,
                "num_blocks": args.actor_num_blocks,
            }
        )
        critic_kwargs.update(
            {
                "scaler_init": math.sqrt(2.0 / args.critic_hidden_dim),
                "scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim),
                "alpha_init": 1.0 / (args.critic_num_blocks + 1),
                "alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim),
                "num_blocks": args.critic_num_blocks,
                "expansion": 4,
                "c_shift": 3.0,
            }
        )
    else:
        raise ValueError(f"Agent {args.agent} not supported")

    actor = actor_cls(**actor_kwargs)

    if env_type in ["mtbench"]:
        # Python 3.8 doesn't support 'from_module' in tensordict
        policy = actor.explore
    else:
        from tensordict import from_module

        actor_detach = actor_cls(**actor_kwargs)
        # Copy params to actor_detach without grad
        from_module(actor).data.to_module(actor_detach)
        policy = actor_detach.explore

    qnet = critic_cls(**critic_kwargs)
    qnet_target = critic_cls(**critic_kwargs)
    qnet_target.load_state_dict(qnet.state_dict())

    q_optimizer = optim.AdamW(
        list(qnet.parameters()),
        lr=torch.tensor(args.critic_learning_rate, device=device),
        weight_decay=args.weight_decay,
    )
    actor_optimizer = optim.AdamW(
        list(actor.parameters()),
        lr=torch.tensor(args.actor_learning_rate, device=device),
        weight_decay=args.weight_decay,
    )
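    # Learning-rate schedule note: CosineAnnealingLR decays each optimizer's
    # learning rate from its initial value toward eta_min, following
    #   lr_t = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / T_max)),
    # so the critic/actor rates anneal toward *_learning_rate_end over training.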
    # Add learning rate schedulers
    q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        q_optimizer,
        T_max=args.total_timesteps,
        eta_min=torch.tensor(args.critic_learning_rate_end, device=device),
    )
    actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        actor_optimizer,
        T_max=args.total_timesteps,
        eta_min=torch.tensor(args.actor_learning_rate_end, device=device),
    )

    rb = SimpleReplayBuffer(
        n_env=args.num_envs,
        buffer_size=args.buffer_size,
        n_obs=n_obs,
        n_act=n_act,
        n_critic_obs=n_critic_obs,
        asymmetric_obs=envs.asymmetric_obs,
        playground_mode=env_type == "mujoco_playground",
        n_steps=args.num_steps,
        gamma=args.gamma,
        device=device,
    )

    policy_noise = args.policy_noise
    noise_clip = args.noise_clip

    def evaluate():
        obs_normalizer.eval()
        num_eval_envs = eval_envs.num_envs
        episode_returns = torch.zeros(num_eval_envs, device=device)
        episode_lengths = torch.zeros(num_eval_envs, device=device)
        done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)

        if env_type == "isaaclab":
            obs = eval_envs.reset(random_start_init=False)
        else:
            obs = eval_envs.reset()

        # Run for a fixed number of steps
        for i in range(eval_envs.max_episode_steps):
            with (
                torch.no_grad(),
                autocast(
                    device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
                ),
            ):
                obs = normalize_obs(obs)
                actions = actor(obs)

            next_obs, rewards, dones, _, infos = eval_envs.step(actions.float())

            if env_type == "mtbench":
                # We only report success rate in MTBench evaluation
                rewards = (
                    infos["episode"]["success"].float() if "episode" in infos else 0.0
                )
            episode_returns = torch.where(
                ~done_masks, episode_returns + rewards, episode_returns
            )
            episode_lengths = torch.where(
                ~done_masks, episode_lengths + 1, episode_lengths
            )
            if env_type == "mtbench" and "episode" in infos:
                dones = dones | infos["episode"]["success"]
            done_masks = torch.logical_or(done_masks, dones)
            if done_masks.all():
                break
            obs = next_obs

        obs_normalizer.train()
        return episode_returns.mean().item(), episode_lengths.mean().item()
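    # The critic is distributional: qnet outputs logits over `num_atoms` fixed
    # value atoms spanning [v_min, v_max]. `qnet_target.projection` builds the
    # projected target distribution from rewards, bootstrap masks, and discounts,
    # and `update_main` below fits the online critic with a cross-entropy loss
    # against that target (the -sum(target * log_softmax(logits)) terms).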
    def update_main(data, logs_dict):
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            observations = data["observations"]
            next_observations = data["next"]["observations"]
            if envs.asymmetric_obs:
                critic_observations = data["critic_observations"]
                next_critic_observations = data["next"]["critic_observations"]
            else:
                critic_observations = observations
                next_critic_observations = next_observations
            actions = data["actions"]
            rewards = data["next"]["rewards"]
            dones = data["next"]["dones"].bool()
            truncations = data["next"]["truncations"].bool()
            if args.disable_bootstrap:
                bootstrap = (~dones).float()
            else:
                bootstrap = (truncations | ~dones).float()

            clipped_noise = torch.randn_like(actions)
            clipped_noise = clipped_noise.mul(policy_noise).clamp(
                -noise_clip, noise_clip
            )

            next_state_actions = (actor(next_observations) + clipped_noise).clamp(
                action_low, action_high
            )
            discount = args.gamma ** data["next"]["effective_n_steps"]

            with torch.no_grad():
                qf1_next_target_projected, qf2_next_target_projected = (
                    qnet_target.projection(
                        next_critic_observations,
                        next_state_actions,
                        rewards,
                        bootstrap,
                        discount,
                    )
                )
                qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
                qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)
                if args.use_cdq:
                    qf_next_target_dist = torch.where(
                        qf1_next_target_value.unsqueeze(1)
                        < qf2_next_target_value.unsqueeze(1),
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )
                    qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist
                else:
                    qf1_next_target_dist, qf2_next_target_dist = (
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )

            qf1, qf2 = qnet(critic_observations, actions)
            qf1_loss = -torch.sum(
                qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1
            ).mean()
            qf2_loss = -torch.sum(
                qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1
            ).mean()
            qf_loss = qf1_loss + qf2_loss

        q_optimizer.zero_grad(set_to_none=True)
        scaler.scale(qf_loss).backward()
        scaler.unscale_(q_optimizer)

        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            qnet.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(q_optimizer)
        scaler.update()

        q_scheduler.step()

        logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
        logs_dict["qf_loss"] = qf_loss.detach()
        logs_dict["qf_max"] = qf1_next_target_value.max().detach()
        logs_dict["qf_min"] = qf1_next_target_value.min().detach()
        return logs_dict

    def update_pol(data, logs_dict):
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            critic_observations = (
                data["critic_observations"]
                if envs.asymmetric_obs
                else data["observations"]
            )

            qf1, qf2 = qnet(critic_observations, actor(data["observations"]))
            qf1_value = qnet.get_value(F.softmax(qf1, dim=1))
            qf2_value = qnet.get_value(F.softmax(qf2, dim=1))
            if args.use_cdq:
                qf_value = torch.minimum(qf1_value, qf2_value)
            else:
                qf_value = (qf1_value + qf2_value) / 2.0
            actor_loss = -qf_value.mean()

        actor_optimizer.zero_grad(set_to_none=True)
        scaler.scale(actor_loss).backward()
        scaler.unscale_(actor_optimizer)
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
            actor.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(actor_optimizer)
        scaler.update()

        actor_scheduler.step()

        logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
        logs_dict["actor_loss"] = actor_loss.detach()
        return logs_dict

    if args.compile:
        mode = None
        update_main = torch.compile(update_main, mode=mode)
        update_pol = torch.compile(update_pol, mode=mode)
        policy = torch.compile(policy, mode=mode)
        normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
        normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
        if args.reward_normalization:
            update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
            normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
    else:
        normalize_obs = obs_normalizer.forward
        normalize_critic_obs = critic_obs_normalizer.forward
        if args.reward_normalization:
            update_stats = reward_normalizer.update_stats
            normalize_reward = reward_normalizer.forward

    if envs.asymmetric_obs:
        obs, critic_obs = envs.reset_with_critic_obs()
        critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
    else:
        obs = envs.reset()

    if args.checkpoint_path:
        # Load checkpoint if specified
        torch_checkpoint = torch.load(
            f"{args.checkpoint_path}", map_location=device, weights_only=False
        )
        actor.load_state_dict(torch_checkpoint["actor_state_dict"])
        obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
        critic_obs_normalizer.load_state_dict(
            torch_checkpoint["critic_obs_normalizer_state"]
        )
        qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
        qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
        global_step = torch_checkpoint["global_step"]
    else:
        global_step = 0

    dones = None
    pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)
    start_time = None
    desc = ""
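    # Main training loop: each iteration collects one transition per parallel
    # environment, writes it to the replay buffer, and, once `learning_starts`
    # steps have elapsed, performs `num_updates` critic updates with delayed
    # actor updates and a Polyak (tau) soft update of the target critic.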
    while global_step < args.total_timesteps:
        logs_dict = TensorDict()
        if (
            start_time is None
            and global_step >= args.measure_burnin + args.learning_starts
        ):
            start_time = time.time()
            measure_burnin = global_step

        with (
            torch.no_grad(),
            autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled),
        ):
            norm_obs = normalize_obs(obs)
            actions = policy(obs=norm_obs, dones=dones)

        next_obs, rewards, dones, _, infos = envs.step(actions.float())
        truncations = infos["time_outs"]

        if args.reward_normalization:
            if env_type == "mtbench":
                task_ids_one_hot = obs[..., -envs.num_tasks :]
                task_indices = torch.argmax(task_ids_one_hot, dim=1)
                update_stats(rewards, dones.float(), task_ids=task_indices)
            else:
                update_stats(rewards, dones.float())

        if envs.asymmetric_obs:
            next_critic_obs = infos["observations"]["critic"]

        # Compute 'true' next_obs and next_critic_obs for saving
        true_next_obs = torch.where(
            dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
        )
        if envs.asymmetric_obs:
            true_next_critic_obs = torch.where(
                dones[:, None] > 0,
                infos["observations"]["raw"]["critic_obs"],
                next_critic_obs,
            )
        transition = TensorDict(
            {
                "observations": obs,
                "actions": torch.as_tensor(actions, device=device, dtype=torch.float),
                "next": {
                    "observations": true_next_obs,
                    "rewards": torch.as_tensor(
                        rewards, device=device, dtype=torch.float
                    ),
                    "truncations": truncations.long(),
                    "dones": dones.long(),
                },
            },
            batch_size=(envs.num_envs,),
            device=device,
        )
        if envs.asymmetric_obs:
            transition["critic_observations"] = critic_obs
            transition["next"]["critic_observations"] = true_next_critic_obs

        obs = next_obs
        if envs.asymmetric_obs:
            critic_obs = next_critic_obs

        rb.extend(transition)

        batch_size = args.batch_size // args.num_envs
        if global_step > args.learning_starts:
            for i in range(args.num_updates):
                data = rb.sample(batch_size)
                data["observations"] = normalize_obs(data["observations"])
                data["next"]["observations"] = normalize_obs(
                    data["next"]["observations"]
                )
                raw_rewards = data["next"]["rewards"]
                if env_type in ["mtbench"] and args.reward_normalization:
                    # Multi-task reward normalization
                    task_ids_one_hot = data["observations"][..., -envs.num_tasks :]
                    task_indices = torch.argmax(task_ids_one_hot, dim=1)
                    data["next"]["rewards"] = normalize_reward(
                        raw_rewards, task_ids=task_indices
                    )
                else:
                    data["next"]["rewards"] = normalize_reward(raw_rewards)
                if envs.asymmetric_obs:
                    data["critic_observations"] = normalize_critic_obs(
                        data["critic_observations"]
                    )
                    data["next"]["critic_observations"] = normalize_critic_obs(
                        data["next"]["critic_observations"]
                    )
                logs_dict = update_main(data, logs_dict)
                if args.num_updates > 1:
                    if i % args.policy_frequency == 1:
                        logs_dict = update_pol(data, logs_dict)
                else:
                    if global_step % args.policy_frequency == 0:
                        logs_dict = update_pol(data, logs_dict)

                for param, target_param in zip(
                    qnet.parameters(), qnet_target.parameters()
                ):
                    target_param.data.copy_(
                        args.tau * param.data + (1 - args.tau) * target_param.data
                    )

        if global_step % 100 == 0 and start_time is not None:
            speed = (global_step - measure_burnin) / (time.time() - start_time)
            pbar.set_description(f"{speed: 4.4f} sps, " + desc)
            with torch.no_grad():
                logs = {
                    "actor_loss": logs_dict["actor_loss"].mean(),
                    "qf_loss": logs_dict["qf_loss"].mean(),
                    "qf_max": logs_dict["qf_max"].mean(),
                    "qf_min": logs_dict["qf_min"].mean(),
                    "actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
                    "critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
                    "env_rewards": rewards.mean(),
                    "buffer_rewards": raw_rewards.mean(),
                }

                if args.eval_interval > 0 and global_step % args.eval_interval == 0:
                    print(f"Evaluating at global step {global_step}")
                    eval_avg_return, eval_avg_length = evaluate()
                    if env_type in ["humanoid_bench", "isaaclab", "mtbench"]:
                        # NOTE: Hacky way of evaluating performance, but just works
                        obs = envs.reset()

                    logs["eval_avg_return"] = eval_avg_return
                    logs["eval_avg_length"] = eval_avg_length

            if args.use_wandb:
                wandb.log(
                    {
                        "speed": speed,
                        "frame": global_step * args.num_envs,
                        "critic_lr": q_scheduler.get_last_lr()[0],
                        "actor_lr": actor_scheduler.get_last_lr()[0],
                        **logs,
                    },
                    step=global_step,
                )

        if (
            args.save_interval > 0
            and global_step > 0
            and global_step % args.save_interval == 0
        ):
            print(f"Saving model at global step {global_step}")
            save_params(
                global_step,
                actor,
                qnet,
                qnet_target,
                obs_normalizer,
                critic_obs_normalizer,
                args,
                f"models/{run_name}_{global_step}.pt",
            )

        global_step += 1
        pbar.update(1)

    save_params(
        global_step,
        actor,
        qnet,
        qnet_target,
        obs_normalizer,
        critic_obs_normalizer,
        args,
        f"models/{run_name}_final.pt",
    )


if __name__ == "__main__":
    main()
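# Usage sketch (illustrative, not part of the training script): restoring a
# saved checkpoint for inference. The key names mirror those read back above
# when `args.checkpoint_path` is set, which is how checkpoints written by
# `save_params` are re-loaded here; the actor and normalizer must be rebuilt
# with the same constructor arguments used during training
# (actor_cls(**actor_kwargs) / EmpiricalNormalization(shape=n_obs, ...)).
#
#   ckpt = torch.load("models/<run_name>_final.pt", map_location=device, weights_only=False)
#   actor.load_state_dict(ckpt["actor_state_dict"])
#   obs_normalizer.load_state_dict(ckpt["obs_normalizer_state"])
#   actor.eval(); obs_normalizer.eval()
#   with torch.no_grad():
#       action = actor(obs_normalizer(obs))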