{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FastTD3 Training Notebook\n", "\n", "Welcome! This notebook will let you execute a series of code blocks that enables you to experience how FastTD3 works -- each block will import packages, define arguments, create environments, create FastTD3 agent, and train the agent.\n", "\n", "This notebook also provide the same functionalities as `train.py` -- you can use this notebook to train your own agents, upload logs to wandb, render rollouts, and fine-tune pre-trained agents with more environment steps!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE: Set environment variables and import packages\n", "\n", "import os\n", "\n", "os.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n", "os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n", "if sys.platform != \"darwin\":\n", " os.environ[\"MUJOCO_GL\"] = \"egl\"\n", "else:\n", " os.environ[\"MUJOCO_GL\"] = \"glfw\"\n", "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n", "os.environ[\"JAX_DEFAULT_MATMUL_PRECISION\"] = \"highest\"\n", "\n", "import random\n", "import time\n", "\n", "import tqdm\n", "import wandb\n", "import numpy as np\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.amp import autocast, GradScaler\n", "from tensordict import TensorDict, from_module\n", "\n", "torch.set_float32_matmul_precision(\"high\")\n", "\n", "from fast_td3_utils import (\n", " EmpiricalNormalization,\n", " SimpleReplayBuffer,\n", " save_params,\n", ")\n", "\n", "from fast_td3 import Critic, Actor" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# NOTE: Set checkpoint if you want to fine-tune from existing checkpoint\n", "# e.g., set checkpoint to \"models/h1-walk-v0_notebook_experiment_30000.pt\"\n", "checkpoint_path = None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# NOTE: Customize arguments as needed\n", "# However, IsaacLab may not work in Notebook Setup.\n", "# We recommend using HumanoidBench or MuJoCo Playground for notebook experiments.\n", "\n", "# For quick experiments, let's use a task without dexterous hands\n", "# But for your research, we recommend using `h1hand` tasks in HumanoidBench!\n", "from hyperparams import HumanoidBenchArgs\n", "\n", "args = HumanoidBenchArgs(\n", " env_name=\"h1-walk-v0\",\n", " total_timesteps=20000,\n", " render_interval=5000,\n", " eval_interval=5000,\n", ")\n", "run_name = f\"{args.env_name}_notebook_experiment\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# NOTE: GPU-Related Configurations\n", "\n", "amp_enabled = args.amp and args.cuda and torch.cuda.is_available()\n", "amp_device_type = (\n", " \"cuda\"\n", " if args.cuda and torch.cuda.is_available()\n", " else \"mps\" if args.cuda and torch.backends.mps.is_available() else \"cpu\"\n", ")\n", "amp_dtype = torch.bfloat16 if args.amp_dtype == \"bf16\" else torch.float16\n", "\n", "scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)\n", "\n", "if not args.cuda:\n", " device = torch.device(\"cpu\")\n", "else:\n", " if torch.cuda.is_available():\n", " device = torch.device(f\"cuda:{args.device_rank}\")\n", " elif torch.backends.mps.is_available():\n", " device = torch.device(f\"mps:{args.device_rank}\")\n", " else:\n", " raise ValueError(\"No GPU available\")\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE: Define Wandb if needed\n", "\n", "# Set use_wandb to True if you want to use Wandb\n", "use_wandb = True\n", "\n", "if use_wandb:\n", " wandb.init(\n", " project=\"FastTD3\",\n", " name=run_name,\n", " config=vars(args),\n", " save_code=True,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE: Initialize Environment and Related Variables\n", "\n", "if args.env_name.startswith(\"h1hand-\") or args.env_name.startswith(\"h1-\"):\n", " from environments.humanoid_bench_env import HumanoidBenchEnv\n", "\n", " env_type = \"humanoid_bench\"\n", " envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n", " eval_envs = envs\n", " render_env = HumanoidBenchEnv(args.env_name, 1, render_mode=\"rgb_array\", device=device)\n", "elif args.env_name.startswith(\"Isaac-\"):\n", " from environments.isaaclab_env import IsaacLabEnv\n", "\n", " env_type = \"isaaclab\"\n", " envs = IsaacLabEnv(\n", " args.env_name,\n", " device.type,\n", " args.num_envs,\n", " args.seed,\n", " action_bounds=args.action_bounds,\n", " )\n", " eval_envs = envs\n", " render_envs = envs\n", "else:\n", " from environments.mujoco_playground_env import make_env\n", " import jax.numpy as jnp\n", "\n", " env_type = \"mujoco_playground\"\n", " envs, eval_envs, render_env = make_env(\n", " args.env_name,\n", " args.seed,\n", " args.num_envs,\n", " args.num_eval_envs,\n", " args.device_rank,\n", " use_tuned_reward=args.use_tuned_reward,\n", " use_domain_randomization=args.use_domain_randomization,\n", " )\n", "\n", "n_act = envs.num_actions\n", "n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]\n", "if envs.asymmetric_obs:\n", " n_critic_obs = (\n", " envs.num_privileged_obs\n", " if type(envs.num_privileged_obs) == int\n", " else envs.num_privileged_obs[0]\n", " )\n", "else:\n", " n_critic_obs = n_obs\n", "action_low, action_high = -1.0, 1.0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# NOTE: Initialize Normalizer, Actor, and Critic\n", "\n", "if args.obs_normalization:\n", " obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)\n", " critic_obs_normalizer = EmpiricalNormalization(shape=n_critic_obs, device=device)\n", "else:\n", " obs_normalizer = nn.Identity()\n", " critic_obs_normalizer = nn.Identity()\n", "\n", "normalize_obs = obs_normalizer.forward\n", "normalize_critic_obs = critic_obs_normalizer.forward\n", "\n", "# Actor setup\n", "actor = Actor(\n", " n_obs=n_obs,\n", " n_act=n_act,\n", " num_envs=args.num_envs,\n", " device=device,\n", " init_scale=args.init_scale,\n", " hidden_dim=args.actor_hidden_dim,\n", ")\n", "actor_detach = Actor(\n", " n_obs=n_obs,\n", " n_act=n_act,\n", " num_envs=args.num_envs,\n", " device=device,\n", " init_scale=args.init_scale,\n", " hidden_dim=args.actor_hidden_dim,\n", ")\n", "# Copy params to actor_detach without grad\n", "from_module(actor).data.to_module(actor_detach)\n", "policy = actor_detach.explore\n", "\n", "qnet = Critic(\n", " n_obs=n_critic_obs,\n", " n_act=n_act,\n", " num_atoms=args.num_atoms,\n", " v_min=args.v_min,\n", " v_max=args.v_max,\n", " hidden_dim=args.critic_hidden_dim,\n", " device=device,\n", ")\n", "qnet_target = Critic(\n", " n_obs=n_critic_obs,\n", " n_act=n_act,\n", " num_atoms=args.num_atoms,\n", " v_min=args.v_min,\n", " v_max=args.v_max,\n", " hidden_dim=args.critic_hidden_dim,\n", " device=device,\n", ")\n", "qnet_target.load_state_dict(qnet.state_dict())\n", "\n", "q_optimizer = optim.AdamW(\n", " list(qnet.parameters()),\n", " lr=args.critic_learning_rate,\n", " weight_decay=args.weight_decay,\n", ")\n", "actor_optimizer = optim.AdamW(\n", " list(actor.parameters()),\n", " lr=args.actor_learning_rate,\n", " weight_decay=args.weight_decay,\n", ")\n", "\n", "rb = SimpleReplayBuffer(\n", " n_env=args.num_envs,\n", " buffer_size=args.buffer_size,\n", " n_obs=n_obs,\n", " n_act=n_act,\n", " n_critic_obs=n_critic_obs,\n", " asymmetric_obs=envs.asymmetric_obs,\n", " n_steps=args.num_steps,\n", " gamma=args.gamma,\n", " device=device,\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# NOTE: Define Evaluation & Rendering Functions\n", "\n", "\n", "def evaluate():\n", " obs_normalizer.eval()\n", " num_eval_envs = eval_envs.num_envs\n", " episode_returns = torch.zeros(num_eval_envs, device=device)\n", " episode_lengths = torch.zeros(num_eval_envs, device=device)\n", " done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)\n", "\n", " if env_type == \"isaaclab\":\n", " obs = eval_envs.reset(random_start_init=False)\n", " else:\n", " obs = eval_envs.reset()\n", "\n", " # Run for a fixed number of steps\n", " for _ in range(eval_envs.max_episode_steps):\n", " with torch.no_grad(), autocast(\n", " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", " ):\n", " obs = normalize_obs(obs)\n", " actions = actor(obs)\n", "\n", " next_obs, rewards, dones, _ = eval_envs.step(actions.float())\n", " episode_returns = torch.where(\n", " ~done_masks, episode_returns + rewards, episode_returns\n", " )\n", " episode_lengths = torch.where(~done_masks, episode_lengths + 1, episode_lengths)\n", " done_masks = torch.logical_or(done_masks, dones)\n", " if done_masks.all():\n", " break\n", " obs = next_obs\n", "\n", " obs_normalizer.train()\n", " return episode_returns.mean().item(), episode_lengths.mean().item()\n", "\n", "\n", "def render_with_rollout():\n", " obs_normalizer.eval()\n", "\n", " # Quick rollout for rendering\n", " if env_type == \"humanoid_bench\":\n", " obs = render_env.reset()\n", " renders = [render_env.render()]\n", " elif env_type == \"isaaclab\":\n", " raise NotImplementedError(\n", " \"We don't support rendering for IsaacLab environments\"\n", " )\n", " else:\n", " obs = render_env.reset()\n", " render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n", " renders = [render_env.state]\n", " for i in range(render_env.max_episode_steps):\n", " with torch.no_grad(), autocast(\n", " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", " ):\n", " obs = normalize_obs(obs)\n", " actions = actor(obs)\n", " next_obs, _, done, _ = render_env.step(actions.float())\n", " if env_type == \"mujoco_playground\":\n", " render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n", " if i % 2 == 0:\n", " if env_type == \"humanoid_bench\":\n", " renders.append(render_env.render())\n", " else:\n", " renders.append(render_env.state)\n", " if done.any():\n", " break\n", " obs = next_obs\n", "\n", " if env_type == \"mujoco_playground\":\n", " renders = render_env.render_trajectory(renders)\n", "\n", " obs_normalizer.train()\n", " return renders" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# NOTE: Define Update Functions\n", "\n", "policy_noise = args.policy_noise\n", "noise_clip = args.noise_clip\n", "\n", "\n", "def update_main(data, logs_dict):\n", " with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n", " observations = data[\"observations\"]\n", " next_observations = data[\"next\"][\"observations\"]\n", " if envs.asymmetric_obs:\n", " critic_observations = data[\"critic_observations\"]\n", " next_critic_observations = data[\"next\"][\"critic_observations\"]\n", " else:\n", " critic_observations = observations\n", " next_critic_observations = next_observations\n", " actions = data[\"actions\"]\n", " rewards = data[\"next\"][\"rewards\"]\n", " dones = data[\"next\"][\"dones\"].bool()\n", " truncations = data[\"next\"][\"truncations\"].bool()\n", " if args.disable_bootstrap:\n", " bootstrap = (~dones).float()\n", " else:\n", " bootstrap = (truncations | ~dones).float()\n", "\n", " clipped_noise = torch.randn_like(actions)\n", " clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip)\n", "\n", " next_state_actions = (actor(next_observations) + clipped_noise).clamp(\n", " action_low, action_high\n", " )\n", "\n", " with torch.no_grad():\n", " qf1_next_target_projected, qf2_next_target_projected = (\n", " qnet_target.projection(\n", " next_critic_observations,\n", " next_state_actions,\n", " rewards,\n", " bootstrap,\n", " args.gamma,\n", " )\n", " )\n", " qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)\n", " qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)\n", " if args.use_cdq:\n", " qf_next_target_dist = torch.where(\n", " qf1_next_target_value.unsqueeze(1)\n", " < qf2_next_target_value.unsqueeze(1),\n", " qf1_next_target_projected,\n", " qf2_next_target_projected,\n", " )\n", " qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist\n", " else:\n", " qf1_next_target_dist, qf2_next_target_dist = (\n", " qf1_next_target_projected,\n", " qf2_next_target_projected,\n", " )\n", "\n", " qf1, qf2 = qnet(critic_observations, actions)\n", " qf1_loss = -torch.sum(\n", " qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1\n", " ).mean()\n", " qf2_loss = -torch.sum(\n", " qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1\n", " ).mean()\n", " qf_loss = qf1_loss + qf2_loss\n", "\n", " q_optimizer.zero_grad(set_to_none=True)\n", " scaler.scale(qf_loss).backward()\n", " scaler.unscale_(q_optimizer)\n", "\n", " critic_grad_norm = torch.nn.utils.clip_grad_norm_(\n", " qnet.parameters(),\n", " max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n", " )\n", " scaler.step(q_optimizer)\n", " scaler.update()\n", "\n", " logs_dict[\"buffer_rewards\"] = rewards.mean()\n", " logs_dict[\"critic_grad_norm\"] = critic_grad_norm.detach()\n", " logs_dict[\"qf_loss\"] = qf_loss.detach()\n", " logs_dict[\"qf_max\"] = qf1_next_target_value.max().detach()\n", " logs_dict[\"qf_min\"] = qf1_next_target_value.min().detach()\n", " return logs_dict\n", "\n", "\n", "def update_pol(data, logs_dict):\n", " with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n", " critic_observations = (\n", " data[\"critic_observations\"] if envs.asymmetric_obs else data[\"observations\"]\n", " )\n", "\n", " qf1, qf2 = qnet(critic_observations, actor(data[\"observations\"]))\n", " qf1_value = qnet.get_value(F.softmax(qf1, dim=1))\n", " qf2_value = qnet.get_value(F.softmax(qf2, dim=1))\n", " if args.use_cdq:\n", " qf_value = torch.minimum(qf1_value, qf2_value)\n", " else:\n", " qf_value = (qf1_value + qf2_value) / 2.0\n", " actor_loss = -qf_value.mean()\n", "\n", " actor_optimizer.zero_grad(set_to_none=True)\n", " scaler.scale(actor_loss).backward()\n", " scaler.unscale_(actor_optimizer)\n", " actor_grad_norm = torch.nn.utils.clip_grad_norm_(\n", " actor.parameters(),\n", " max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n", " )\n", " scaler.step(actor_optimizer)\n", " scaler.update()\n", " logs_dict[\"actor_grad_norm\"] = actor_grad_norm.detach()\n", " logs_dict[\"actor_loss\"] = actor_loss.detach()\n", " return logs_dict" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# NOTE: Compile Functions if Needed\n", "\n", "if args.compile:\n", " mode = None\n", " update_main = torch.compile(update_main, mode=mode)\n", " update_pol = torch.compile(update_pol, mode=mode)\n", " policy = torch.compile(policy, mode=mode)\n", " normalize_obs = torch.compile(normalize_obs, mode=mode)\n", " normalize_critic_obs = torch.compile(normalize_critic_obs, mode=mode)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# NOTE: Load Checkpoint if Needed\n", "if checkpoint_path is not None:\n", " torch_checkpoint = torch.load(\n", " f\"{checkpoint_path}\", map_location=device, weights_only=False\n", " )\n", "\n", " actor.load_state_dict(torch_checkpoint[\"actor_state_dict\"])\n", " obs_normalizer.load_state_dict(torch_checkpoint[\"obs_normalizer_state\"])\n", " critic_obs_normalizer.load_state_dict(\n", " torch_checkpoint[\"critic_obs_normalizer_state\"]\n", " )\n", " qnet.load_state_dict(torch_checkpoint[\"qnet_state_dict\"])\n", " qnet_target.load_state_dict(torch_checkpoint[\"qnet_target_state_dict\"])\n", " global_step = torch_checkpoint[\"global_step\"]\n", "else:\n", " global_step = 0" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# NOTE: Utility functions for displaying videos in notebook\n", "\n", "from IPython.display import display, HTML\n", "import base64\n", "import imageio\n", "import tempfile\n", "import os\n", "\n", "\n", "def frames_to_video_html(frames, fps=30):\n", " \"\"\"\n", " Convert a list of numpy arrays to an HTML5 video element.\n", "\n", " Args:\n", " frames (list): List of numpy arrays representing video frames\n", " fps (int): Frames per second for the video\n", "\n", " Returns:\n", " HTML object containing the video element\n", " \"\"\"\n", " # Create a temporary file to store the video\n", " with tempfile.NamedTemporaryFile(suffix=\".mp4\", delete=False) as temp_file:\n", " temp_filename = temp_file.name\n", "\n", " # Save frames as video\n", " imageio.mimsave(temp_filename, frames, fps=fps)\n", "\n", " # Read the video file and encode it to base64\n", " with open(temp_filename, \"rb\") as f:\n", " video_data = f.read()\n", " video_b64 = base64.b64encode(video_data).decode(\"utf-8\")\n", "\n", " # Create HTML video element\n", " video_html = f\"\"\"\n", " \n", " \"\"\"\n", "\n", " # Clean up the temporary file\n", " os.unlink(temp_filename)\n", "\n", " return HTML(video_html)\n", "\n", "\n", "def update_video_display(frames, fps=30):\n", " \"\"\"\n", " Display video frames as an embedded HTML5 video element.\n", "\n", " Args:\n", " frames (list): List of numpy arrays representing video frames\n", " fps (int): Frames per second for the video\n", " \"\"\"\n", " video_html = frames_to_video_html(frames, fps=fps)\n", " display(video_html)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# NOTE: Main Training Loop\n", "\n", "if envs.asymmetric_obs:\n", " obs, critic_obs = envs.reset_with_critic_obs()\n", " critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)\n", "else:\n", " obs = envs.reset()\n", "pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)\n", "\n", "dones = None\n", "while global_step < args.total_timesteps:\n", " logs_dict = TensorDict()\n", " with torch.no_grad(), autocast(\n", " device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n", " ):\n", " norm_obs = normalize_obs(obs)\n", " actions = policy(obs=norm_obs, dones=dones)\n", "\n", " next_obs, rewards, dones, infos = envs.step(actions.float())\n", " truncations = infos[\"time_outs\"]\n", "\n", " if envs.asymmetric_obs:\n", " next_critic_obs = infos[\"observations\"][\"critic\"]\n", "\n", " # Compute 'true' next_obs and next_critic_obs for saving\n", " true_next_obs = torch.where(\n", " dones[:, None] > 0, infos[\"observations\"][\"raw\"][\"obs\"], next_obs\n", " )\n", " if envs.asymmetric_obs:\n", " true_next_critic_obs = torch.where(\n", " dones[:, None] > 0,\n", " infos[\"observations\"][\"raw\"][\"critic_obs\"],\n", " next_critic_obs,\n", " )\n", " transition = TensorDict(\n", " {\n", " \"observations\": obs,\n", " \"actions\": torch.as_tensor(actions, device=device, dtype=torch.float),\n", " \"next\": {\n", " \"observations\": true_next_obs,\n", " \"rewards\": torch.as_tensor(rewards, device=device, dtype=torch.float),\n", " \"truncations\": truncations.long(),\n", " \"dones\": dones.long(),\n", " },\n", " },\n", " batch_size=(envs.num_envs,),\n", " device=device,\n", " )\n", " if envs.asymmetric_obs:\n", " transition[\"critic_observations\"] = critic_obs\n", " transition[\"next\"][\"critic_observations\"] = true_next_critic_obs\n", "\n", " obs = next_obs\n", " if envs.asymmetric_obs:\n", " critic_obs = next_critic_obs\n", "\n", " rb.extend(transition)\n", "\n", " batch_size = args.batch_size // args.num_envs\n", " if global_step > args.learning_starts:\n", " for i in range(args.num_updates):\n", " data = rb.sample(batch_size)\n", " data[\"observations\"] = normalize_obs(data[\"observations\"])\n", " data[\"next\"][\"observations\"] = normalize_obs(data[\"next\"][\"observations\"])\n", " if envs.asymmetric_obs:\n", " data[\"critic_observations\"] = normalize_critic_obs(\n", " data[\"critic_observations\"]\n", " )\n", " data[\"next\"][\"critic_observations\"] = normalize_critic_obs(\n", " data[\"next\"][\"critic_observations\"]\n", " )\n", " logs_dict = update_main(data, logs_dict)\n", " if args.num_updates > 1:\n", " if i % args.policy_frequency == 1:\n", " logs_dict = update_pol(data, logs_dict)\n", " else:\n", " if global_step % args.policy_frequency == 0:\n", " logs_dict = update_pol(data, logs_dict)\n", "\n", " for param, target_param in zip(qnet.parameters(), qnet_target.parameters()):\n", " target_param.data.copy_(\n", " args.tau * param.data + (1 - args.tau) * target_param.data\n", " )\n", "\n", " if global_step > 0 and global_step % 100 == 0:\n", " with torch.no_grad():\n", " logs = {\n", " \"actor_loss\": logs_dict[\"actor_loss\"].mean(),\n", " \"qf_loss\": logs_dict[\"qf_loss\"].mean(),\n", " \"qf_max\": logs_dict[\"qf_max\"].mean(),\n", " \"qf_min\": logs_dict[\"qf_min\"].mean(),\n", " \"actor_grad_norm\": logs_dict[\"actor_grad_norm\"].mean(),\n", " \"critic_grad_norm\": logs_dict[\"critic_grad_norm\"].mean(),\n", " \"buffer_rewards\": logs_dict[\"buffer_rewards\"].mean(),\n", " \"env_rewards\": rewards.mean(),\n", " }\n", "\n", " if args.eval_interval > 0 and global_step % args.eval_interval == 0:\n", " eval_avg_return, eval_avg_length = evaluate()\n", " if env_type in [\"humanoid_bench\", \"isaaclab\"]:\n", " # NOTE: Hacky way of evaluating performance, but just works\n", " obs = envs.reset()\n", " logs[\"eval_avg_return\"] = eval_avg_return\n", " logs[\"eval_avg_length\"] = eval_avg_length\n", "\n", " if args.render_interval > 0 and global_step % args.render_interval == 0:\n", " renders = render_with_rollout()\n", " print_logs = {\n", " k: v.item() if isinstance(v, torch.Tensor) else v\n", " for k, v in logs.items()\n", " }\n", " for k, v in print_logs.items():\n", " print(f\"{k}: {v:.4f}\")\n", " update_video_display(renders, fps=30)\n", " if use_wandb:\n", " wandb.log(\n", " {\n", " \"render_video\": wandb.Video(\n", " np.array(renders).transpose(\n", " 0, 3, 1, 2\n", " ), # Convert to (T, C, H, W) format\n", " fps=30,\n", " format=\"gif\",\n", " )\n", " },\n", " step=global_step,\n", " )\n", " if use_wandb:\n", " wandb.log(\n", " {\n", " \"frame\": global_step * args.num_envs,\n", " **logs,\n", " },\n", " step=global_step,\n", " )\n", "\n", " if (\n", " args.save_interval > 0\n", " and global_step > 0\n", " and global_step % args.save_interval == 0\n", " ):\n", " save_params(\n", " global_step,\n", " actor,\n", " qnet,\n", " qnet_target,\n", " obs_normalizer,\n", " critic_obs_normalizer,\n", " args,\n", " f\"models/{run_name}_{global_step}.pt\",\n", " )\n", "\n", " global_step += 1\n", " pbar.update(1)\n", "\n", "save_params(\n", " global_step,\n", " actor,\n", " qnet,\n", " qnet_target,\n", " obs_normalizer,\n", " critic_obs_normalizer,\n", " args,\n", " f\"models/{run_name}_final.pt\",\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "fasttd3_hb", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.17" } }, "nbformat": 4, "nbformat_minor": 2 }