This PR includes these changes: - Fixes a bug in MTBench evaluation - Adds a missing `critic_cls` in `train.py` (resolving issue https://github.com/younggyoseo/FastTD3/issues/17) - Updates hyperparameters for MTBench
789 lines
30 KiB
Plaintext
789 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# FastTD3 Training Notebook\n",
|
|
"\n",
|
|
"Welcome! This notebook will let you execute a series of code blocks that enables you to experience how FastTD3 works -- each block will import packages, define arguments, create environments, create FastTD3 agent, and train the agent.\n",
|
|
"\n",
|
|
"This notebook also provide the same functionalities as `train.py` -- you can use this notebook to train your own agents, upload logs to wandb, render rollouts, and fine-tune pre-trained agents with more environment steps!"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Set environment variables and import packages\n",
|
|
"\n",
|
|
"import os\n",
|

|

"import sys  # required below for sys.platform; was missing and caused a NameError on a fresh kernel\n",
|

|

"\n",
|

|

"os.environ[\"TORCHDYNAMO_INLINE_INBUILT_NN_MODULES\"] = \"1\"\n",
|

|

"os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n",
|

|

"# Use EGL for headless GPU rendering on Linux; macOS (darwin) needs glfw\n",
|

|

"if sys.platform != \"darwin\":\n",
|

|

" os.environ[\"MUJOCO_GL\"] = \"egl\"\n",
|

|

"else:\n",
|

|

" os.environ[\"MUJOCO_GL\"] = \"glfw\"\n",
|

|

"os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
|

|

"os.environ[\"JAX_DEFAULT_MATMUL_PRECISION\"] = \"highest\"\n",
|
|
"\n",
|
|
"import random\n",
|
|
"import time\n",
|
|
"\n",
|
|
"import tqdm\n",
|
|
"import wandb\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import torch.optim as optim\n",
|
|
"from torch.amp import autocast, GradScaler\n",
|
|
"from tensordict import TensorDict, from_module\n",
|
|
"\n",
|
|
"torch.set_float32_matmul_precision(\"high\")\n",
|
|
"\n",
|
|
"from fast_td3_utils import (\n",
|
|
" EmpiricalNormalization,\n",
|
|
" SimpleReplayBuffer,\n",
|
|
" save_params,\n",
|
|
")\n",
|
|
"\n",
|
|
"from fast_td3 import Critic, Actor"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Set checkpoint if you want to fine-tune from existing checkpoint\n",
|
|
"# e.g., set checkpoint to \"models/h1-walk-v0_notebook_experiment_30000.pt\"\n",
|
|
"checkpoint_path = None"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Customize arguments as needed\n",
|
|
"# However, IsaacLab may not work in Notebook Setup.\n",
|
|
"# We recommend using HumanoidBench or MuJoCo Playground for notebook experiments.\n",
|
|
"\n",
|
|
"# For quick experiments, let's use a task without dexterous hands\n",
|
|
"# But for your research, we recommend using `h1hand` tasks in HumanoidBench!\n",
|
|
"from hyperparams import HumanoidBenchArgs\n",
|
|
"\n",
|
|
"args = HumanoidBenchArgs(\n",
|
|
" env_name=\"h1-walk-v0\",\n",
|
|
" total_timesteps=20000,\n",
|
|
" render_interval=5000,\n",
|
|
" eval_interval=5000,\n",
|
|
")\n",
|
|
"run_name = f\"{args.env_name}_notebook_experiment\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: GPU-Related Configurations\n",
|
|
"\n",
|
|
"amp_enabled = args.amp and args.cuda and torch.cuda.is_available()\n",
|
|
"amp_device_type = (\n",
|
|
" \"cuda\"\n",
|
|
" if args.cuda and torch.cuda.is_available()\n",
|
|
" else \"mps\" if args.cuda and torch.backends.mps.is_available() else \"cpu\"\n",
|
|
")\n",
|
|
"amp_dtype = torch.bfloat16 if args.amp_dtype == \"bf16\" else torch.float16\n",
|
|
"\n",
|
|
"scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)\n",
|
|
"\n",
|
|
"if not args.cuda:\n",
|
|
" device = torch.device(\"cpu\")\n",
|
|
"else:\n",
|
|
" if torch.cuda.is_available():\n",
|
|
" device = torch.device(f\"cuda:{args.device_rank}\")\n",
|
|
" elif torch.backends.mps.is_available():\n",
|
|
" device = torch.device(f\"mps:{args.device_rank}\")\n",
|
|
" else:\n",
|
|
" raise ValueError(\"No GPU available\")\n",
|
|
"print(f\"Using device: {device}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Define Wandb if needed\n",
|
|
"\n",
|
|
"# Set use_wandb to True if you want to use Wandb\n",
|
|
"use_wandb = True\n",
|
|
"\n",
|
|
"if use_wandb:\n",
|
|
" wandb.init(\n",
|
|
" project=\"FastTD3\",\n",
|
|
" name=run_name,\n",
|
|
" config=vars(args),\n",
|
|
" save_code=True,\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Initialize Environment and Related Variables\n",
|
|
"\n",
|
|
"if args.env_name.startswith(\"h1hand-\") or args.env_name.startswith(\"h1-\"):\n",
|
|
" from environments.humanoid_bench_env import HumanoidBenchEnv\n",
|
|
"\n",
|
|
" env_type = \"humanoid_bench\"\n",
|
|
" envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)\n",
|
|
" eval_envs = envs\n",
|
|
" render_env = HumanoidBenchEnv(\n",
|
|
" args.env_name, 1, render_mode=\"rgb_array\", device=device\n",
|
|
" )\n",
|
|
"elif args.env_name.startswith(\"Isaac-\"):\n",
|

|

" from environments.isaaclab_env import IsaacLabEnv\n",
|

|

"\n",
|

|

" env_type = \"isaaclab\"\n",
|

|

" envs = IsaacLabEnv(\n",
|

|

" args.env_name,\n",
|

|

" device.type,\n",
|

|

" args.num_envs,\n",
|

|

" args.seed,\n",
|

|

" action_bounds=args.action_bounds,\n",
|

|

" )\n",
|

|

" eval_envs = envs\n",
|

|

" # Fixed: was `render_envs` (typo); all other branches and the rendering code use `render_env`.\n",
|

|

" # Rendering is unsupported for IsaacLab (render_with_rollout raises), but the name must still match.\n",
|

|

" render_env = envs\n",
|
|
"else:\n",
|
|
" from environments.mujoco_playground_env import make_env\n",
|
|
" import jax.numpy as jnp\n",
|
|
"\n",
|
|
" env_type = \"mujoco_playground\"\n",
|
|
" envs, eval_envs, render_env = make_env(\n",
|
|
" args.env_name,\n",
|
|
" args.seed,\n",
|
|
" args.num_envs,\n",
|
|
" args.num_eval_envs,\n",
|
|
" args.device_rank,\n",
|
|
" use_tuned_reward=args.use_tuned_reward,\n",
|
|
" use_domain_randomization=args.use_domain_randomization,\n",
|
|
" )\n",
|
|
"\n",
|
|
"n_act = envs.num_actions\n",
|
|
"n_obs = envs.num_obs if type(envs.num_obs) == int else envs.num_obs[0]\n",
|
|
"if envs.asymmetric_obs:\n",
|
|
" n_critic_obs = (\n",
|
|
" envs.num_privileged_obs\n",
|
|
" if type(envs.num_privileged_obs) == int\n",
|
|
" else envs.num_privileged_obs[0]\n",
|
|
" )\n",
|
|
"else:\n",
|
|
" n_critic_obs = n_obs\n",
|
|
"action_low, action_high = -1.0, 1.0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Initialize Normalizer, Actor, and Critic\n",
|
|
"\n",
|
|
"if args.obs_normalization:\n",
|
|
" obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)\n",
|
|
" critic_obs_normalizer = EmpiricalNormalization(shape=n_critic_obs, device=device)\n",
|
|
"else:\n",
|
|
" obs_normalizer = nn.Identity()\n",
|
|
" critic_obs_normalizer = nn.Identity()\n",
|
|
"\n",
|
|
"normalize_obs = obs_normalizer.forward\n",
|
|
"normalize_critic_obs = critic_obs_normalizer.forward\n",
|
|
"\n",
|
|
"# Actor setup\n",
|
|
"actor = Actor(\n",
|
|
" n_obs=n_obs,\n",
|
|
" n_act=n_act,\n",
|
|
" num_envs=args.num_envs,\n",
|
|
" device=device,\n",
|
|
" init_scale=args.init_scale,\n",
|
|
" hidden_dim=args.actor_hidden_dim,\n",
|
|
")\n",
|
|
"actor_detach = Actor(\n",
|
|
" n_obs=n_obs,\n",
|
|
" n_act=n_act,\n",
|
|
" num_envs=args.num_envs,\n",
|
|
" device=device,\n",
|
|
" init_scale=args.init_scale,\n",
|
|
" hidden_dim=args.actor_hidden_dim,\n",
|
|
")\n",
|
|
"# Copy params to actor_detach without grad\n",
|
|
"from_module(actor).data.to_module(actor_detach)\n",
|
|
"policy = actor_detach.explore\n",
|
|
"\n",
|
|
"qnet = Critic(\n",
|
|
" n_obs=n_critic_obs,\n",
|
|
" n_act=n_act,\n",
|
|
" num_atoms=args.num_atoms,\n",
|
|
" v_min=args.v_min,\n",
|
|
" v_max=args.v_max,\n",
|
|
" hidden_dim=args.critic_hidden_dim,\n",
|
|
" device=device,\n",
|
|
")\n",
|
|
"qnet_target = Critic(\n",
|
|
" n_obs=n_critic_obs,\n",
|
|
" n_act=n_act,\n",
|
|
" num_atoms=args.num_atoms,\n",
|
|
" v_min=args.v_min,\n",
|
|
" v_max=args.v_max,\n",
|
|
" hidden_dim=args.critic_hidden_dim,\n",
|
|
" device=device,\n",
|
|
")\n",
|
|
"qnet_target.load_state_dict(qnet.state_dict())\n",
|
|
"\n",
|
|
"q_optimizer = optim.AdamW(\n",
|
|
" list(qnet.parameters()),\n",
|
|
" lr=args.critic_learning_rate,\n",
|
|
" weight_decay=args.weight_decay,\n",
|
|
")\n",
|
|
"actor_optimizer = optim.AdamW(\n",
|
|
" list(actor.parameters()),\n",
|
|
" lr=args.actor_learning_rate,\n",
|
|
" weight_decay=args.weight_decay,\n",
|
|
")\n",
|
|
"\n",
|
|
"rb = SimpleReplayBuffer(\n",
|
|
" n_env=args.num_envs,\n",
|
|
" buffer_size=args.buffer_size,\n",
|
|
" n_obs=n_obs,\n",
|
|
" n_act=n_act,\n",
|
|
" n_critic_obs=n_critic_obs,\n",
|
|
" asymmetric_obs=envs.asymmetric_obs,\n",
|
|
" playground_mode=env_type == \"mujoco_playground\",\n",
|
|
" n_steps=args.num_steps,\n",
|
|
" gamma=args.gamma,\n",
|
|
" device=device,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Define Evaluation & Rendering Functions\n",
|
|
"\n",
|
|
"\n",
|
|
"def evaluate():\n",
|
|
" obs_normalizer.eval()\n",
|
|
" num_eval_envs = eval_envs.num_envs\n",
|
|
" episode_returns = torch.zeros(num_eval_envs, device=device)\n",
|
|
" episode_lengths = torch.zeros(num_eval_envs, device=device)\n",
|
|
" done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)\n",
|
|
"\n",
|
|
" if env_type == \"isaaclab\":\n",
|
|
" obs = eval_envs.reset(random_start_init=False)\n",
|
|
" else:\n",
|
|
" obs = eval_envs.reset()\n",
|
|
"\n",
|
|
" # Run for a fixed number of steps\n",
|
|
" for _ in range(eval_envs.max_episode_steps):\n",
|
|
" with torch.no_grad(), autocast(\n",
|
|
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
|
|
" ):\n",
|
|
" obs = normalize_obs(obs)\n",
|
|
" actions = actor(obs)\n",
|
|
"\n",
|
|
" next_obs, rewards, dones, _ = eval_envs.step(actions.float())\n",
|
|
" episode_returns = torch.where(\n",
|
|
" ~done_masks, episode_returns + rewards, episode_returns\n",
|
|
" )\n",
|
|
" episode_lengths = torch.where(~done_masks, episode_lengths + 1, episode_lengths)\n",
|
|
" done_masks = torch.logical_or(done_masks, dones)\n",
|
|
" if done_masks.all():\n",
|
|
" break\n",
|
|
" obs = next_obs\n",
|
|
"\n",
|
|
" obs_normalizer.train()\n",
|
|
" return episode_returns.mean().item(), episode_lengths.mean().item()\n",
|
|
"\n",
|
|
"\n",
|
|
"def render_with_rollout():\n",
|
|
" obs_normalizer.eval()\n",
|
|
"\n",
|
|
" # Quick rollout for rendering\n",
|
|
" if env_type == \"humanoid_bench\":\n",
|
|
" obs = render_env.reset()\n",
|
|
" renders = [render_env.render()]\n",
|
|
" elif env_type == \"isaaclab\":\n",
|
|
" raise NotImplementedError(\n",
|
|
" \"We don't support rendering for IsaacLab environments\"\n",
|
|
" )\n",
|
|
" else:\n",
|
|
" obs = render_env.reset()\n",
|
|
" render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n",
|
|
" renders = [render_env.state]\n",
|
|
" for i in range(render_env.max_episode_steps):\n",
|
|
" with torch.no_grad(), autocast(\n",
|
|
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
|
|
" ):\n",
|
|
" obs = normalize_obs(obs)\n",
|
|
" actions = actor(obs)\n",
|
|
" next_obs, _, done, _ = render_env.step(actions.float())\n",
|
|
" if env_type == \"mujoco_playground\":\n",
|
|
" render_env.state.info[\"command\"] = jnp.array([[1.0, 0.0, 0.0]])\n",
|
|
" if i % 2 == 0:\n",
|
|
" if env_type == \"humanoid_bench\":\n",
|
|
" renders.append(render_env.render())\n",
|
|
" else:\n",
|
|
" renders.append(render_env.state)\n",
|
|
" if done.any():\n",
|
|
" break\n",
|
|
" obs = next_obs\n",
|
|
"\n",
|
|
" if env_type == \"mujoco_playground\":\n",
|
|
" renders = render_env.render_trajectory(renders)\n",
|
|
"\n",
|
|
" obs_normalizer.train()\n",
|
|
" return renders"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Define Update Functions\n",
|
|
"\n",
|
|
"policy_noise = args.policy_noise\n",
|
|
"noise_clip = args.noise_clip\n",
|
|
"\n",
|
|
"\n",
|
|
"def update_main(data, logs_dict):\n",
|
|
" with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n",
|
|
" observations = data[\"observations\"]\n",
|
|
" next_observations = data[\"next\"][\"observations\"]\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" critic_observations = data[\"critic_observations\"]\n",
|
|
" next_critic_observations = data[\"next\"][\"critic_observations\"]\n",
|
|
" else:\n",
|
|
" critic_observations = observations\n",
|
|
" next_critic_observations = next_observations\n",
|
|
" actions = data[\"actions\"]\n",
|
|
" rewards = data[\"next\"][\"rewards\"]\n",
|
|
" dones = data[\"next\"][\"dones\"].bool()\n",
|
|
" truncations = data[\"next\"][\"truncations\"].bool()\n",
|
|
" if args.disable_bootstrap:\n",
|
|
" bootstrap = (~dones).float()\n",
|
|
" else:\n",
|
|
" bootstrap = (truncations | ~dones).float()\n",
|
|
"\n",
|
|
" clipped_noise = torch.randn_like(actions)\n",
|
|
" clipped_noise = clipped_noise.mul(policy_noise).clamp(-noise_clip, noise_clip)\n",
|
|
"\n",
|
|
" next_state_actions = (actor(next_observations) + clipped_noise).clamp(\n",
|
|
" action_low, action_high\n",
|
|
" )\n",
|
|
"\n",
|
|
" with torch.no_grad():\n",
|
|
" qf1_next_target_projected, qf2_next_target_projected = (\n",
|
|
" qnet_target.projection(\n",
|
|
" next_critic_observations,\n",
|
|
" next_state_actions,\n",
|
|
" rewards,\n",
|
|
" bootstrap,\n",
|
|
" args.gamma,\n",
|
|
" )\n",
|
|
" )\n",
|
|
" qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)\n",
|
|
" qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)\n",
|
|
" if args.use_cdq:\n",
|
|
" qf_next_target_dist = torch.where(\n",
|
|
" qf1_next_target_value.unsqueeze(1)\n",
|
|
" < qf2_next_target_value.unsqueeze(1),\n",
|
|
" qf1_next_target_projected,\n",
|
|
" qf2_next_target_projected,\n",
|
|
" )\n",
|
|
" qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist\n",
|
|
" else:\n",
|
|
" qf1_next_target_dist, qf2_next_target_dist = (\n",
|
|
" qf1_next_target_projected,\n",
|
|
" qf2_next_target_projected,\n",
|
|
" )\n",
|
|
"\n",
|
|
" qf1, qf2 = qnet(critic_observations, actions)\n",
|
|
" qf1_loss = -torch.sum(\n",
|
|
" qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1\n",
|
|
" ).mean()\n",
|
|
" qf2_loss = -torch.sum(\n",
|
|
" qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1\n",
|
|
" ).mean()\n",
|
|
" qf_loss = qf1_loss + qf2_loss\n",
|
|
"\n",
|
|
" q_optimizer.zero_grad(set_to_none=True)\n",
|
|
" scaler.scale(qf_loss).backward()\n",
|
|
" scaler.unscale_(q_optimizer)\n",
|
|
"\n",
|
|
" critic_grad_norm = torch.nn.utils.clip_grad_norm_(\n",
|
|
" qnet.parameters(),\n",
|
|
" max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n",
|
|
" )\n",
|
|
" scaler.step(q_optimizer)\n",
|
|
" scaler.update()\n",
|
|
"\n",
|
|
" logs_dict[\"buffer_rewards\"] = rewards.mean()\n",
|
|
" logs_dict[\"critic_grad_norm\"] = critic_grad_norm.detach()\n",
|
|
" logs_dict[\"qf_loss\"] = qf_loss.detach()\n",
|
|
" logs_dict[\"qf_max\"] = qf1_next_target_value.max().detach()\n",
|
|
" logs_dict[\"qf_min\"] = qf1_next_target_value.min().detach()\n",
|
|
" return logs_dict\n",
|
|
"\n",
|
|
"\n",
|
|
"def update_pol(data, logs_dict):\n",
|
|
" with autocast(device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled):\n",
|
|
" critic_observations = (\n",
|
|
" data[\"critic_observations\"] if envs.asymmetric_obs else data[\"observations\"]\n",
|
|
" )\n",
|
|
"\n",
|
|
" qf1, qf2 = qnet(critic_observations, actor(data[\"observations\"]))\n",
|
|
" qf1_value = qnet.get_value(F.softmax(qf1, dim=1))\n",
|
|
" qf2_value = qnet.get_value(F.softmax(qf2, dim=1))\n",
|
|
" if args.use_cdq:\n",
|
|
" qf_value = torch.minimum(qf1_value, qf2_value)\n",
|
|
" else:\n",
|
|
" qf_value = (qf1_value + qf2_value) / 2.0\n",
|
|
" actor_loss = -qf_value.mean()\n",
|
|
"\n",
|
|
" actor_optimizer.zero_grad(set_to_none=True)\n",
|
|
" scaler.scale(actor_loss).backward()\n",
|
|
" scaler.unscale_(actor_optimizer)\n",
|
|
" actor_grad_norm = torch.nn.utils.clip_grad_norm_(\n",
|
|
" actor.parameters(),\n",
|
|
" max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float(\"inf\"),\n",
|
|
" )\n",
|
|
" scaler.step(actor_optimizer)\n",
|
|
" scaler.update()\n",
|
|
" logs_dict[\"actor_grad_norm\"] = actor_grad_norm.detach()\n",
|
|
" logs_dict[\"actor_loss\"] = actor_loss.detach()\n",
|
|
" return logs_dict"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Compile Functions if Needed\n",
|
|
"\n",
|
|
"if args.compile:\n",
|
|
" mode = None\n",
|
|
" update_main = torch.compile(update_main, mode=mode)\n",
|
|
" update_pol = torch.compile(update_pol, mode=mode)\n",
|
|
" policy = torch.compile(policy, mode=mode)\n",
|
|
" normalize_obs = torch.compile(normalize_obs, mode=mode)\n",
|
|
" normalize_critic_obs = torch.compile(normalize_critic_obs, mode=mode)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Load Checkpoint if Needed\n",
|
|
"if checkpoint_path is not None:\n",
|
|
" torch_checkpoint = torch.load(\n",
|
|
" f\"{checkpoint_path}\", map_location=device, weights_only=False\n",
|
|
" )\n",
|
|
"\n",
|
|
" actor.load_state_dict(torch_checkpoint[\"actor_state_dict\"])\n",
|
|
" obs_normalizer.load_state_dict(torch_checkpoint[\"obs_normalizer_state\"])\n",
|
|
" critic_obs_normalizer.load_state_dict(\n",
|
|
" torch_checkpoint[\"critic_obs_normalizer_state\"]\n",
|
|
" )\n",
|
|
" qnet.load_state_dict(torch_checkpoint[\"qnet_state_dict\"])\n",
|
|
" qnet_target.load_state_dict(torch_checkpoint[\"qnet_target_state_dict\"])\n",
|
|
" global_step = torch_checkpoint[\"global_step\"]\n",
|
|
"else:\n",
|
|
" global_step = 0"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Utility functions for displaying videos in notebook\n",
|
|
"\n",
|
|
"from IPython.display import display, HTML\n",
|
|
"import base64\n",
|
|
"import imageio\n",
|
|
"import tempfile\n",
|
|
"import os\n",
|
|
"\n",
|
|
"\n",
|
|
"def frames_to_video_html(frames, fps=30):\n",
|
|
" \"\"\"\n",
|
|
" Convert a list of numpy arrays to an HTML5 video element.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" frames (list): List of numpy arrays representing video frames\n",
|
|
" fps (int): Frames per second for the video\n",
|
|
"\n",
|
|
" Returns:\n",
|
|
" HTML object containing the video element\n",
|
|
" \"\"\"\n",
|
|
" # Create a temporary file to store the video\n",
|
|
" with tempfile.NamedTemporaryFile(suffix=\".mp4\", delete=False) as temp_file:\n",
|
|
" temp_filename = temp_file.name\n",
|
|
"\n",
|
|
" # Save frames as video\n",
|
|
" imageio.mimsave(temp_filename, frames, fps=fps)\n",
|
|
"\n",
|
|
" # Read the video file and encode it to base64\n",
|
|
" with open(temp_filename, \"rb\") as f:\n",
|
|
" video_data = f.read()\n",
|
|
" video_b64 = base64.b64encode(video_data).decode(\"utf-8\")\n",
|
|
"\n",
|
|
" # Create HTML video element\n",
|
|
" video_html = f\"\"\"\n",
|
|
" <video width=\"640\" height=\"480\" controls>\n",
|
|
" <source src=\"data:video/mp4;base64,{video_b64}\" type=\"video/mp4\">\n",
|
|
" Your browser does not support the video tag.\n",
|
|
" </video>\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" # Clean up the temporary file\n",
|
|
" os.unlink(temp_filename)\n",
|
|
"\n",
|
|
" return HTML(video_html)\n",
|
|
"\n",
|
|
"\n",
|
|
"def update_video_display(frames, fps=30):\n",
|
|
" \"\"\"\n",
|
|
" Display video frames as an embedded HTML5 video element.\n",
|
|
"\n",
|
|
" Args:\n",
|
|
" frames (list): List of numpy arrays representing video frames\n",
|
|
" fps (int): Frames per second for the video\n",
|
|
" \"\"\"\n",
|
|
" video_html = frames_to_video_html(frames, fps=fps)\n",
|
|
" display(video_html)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE: Main Training Loop\n",
|
|
"\n",
|
|
"if envs.asymmetric_obs:\n",
|
|
" obs, critic_obs = envs.reset_with_critic_obs()\n",
|
|
" critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)\n",
|
|
"else:\n",
|
|
" obs = envs.reset()\n",
|
|
"pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)\n",
|
|
"\n",
|
|
"dones = None\n",
|
|
"while global_step < args.total_timesteps:\n",
|
|
" logs_dict = TensorDict()\n",
|
|
" with torch.no_grad(), autocast(\n",
|
|
" device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled\n",
|
|
" ):\n",
|
|
" norm_obs = normalize_obs(obs)\n",
|
|
" actions = policy(obs=norm_obs, dones=dones)\n",
|
|
"\n",
|
|
" next_obs, rewards, dones, infos = envs.step(actions.float())\n",
|
|
" truncations = infos[\"time_outs\"]\n",
|
|
"\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" next_critic_obs = infos[\"observations\"][\"critic\"]\n",
|
|
"\n",
|
|
" # Compute 'true' next_obs and next_critic_obs for saving\n",
|
|
" true_next_obs = torch.where(\n",
|
|
" dones[:, None] > 0, infos[\"observations\"][\"raw\"][\"obs\"], next_obs\n",
|
|
" )\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" true_next_critic_obs = torch.where(\n",
|
|
" dones[:, None] > 0,\n",
|
|
" infos[\"observations\"][\"raw\"][\"critic_obs\"],\n",
|
|
" next_critic_obs,\n",
|
|
" )\n",
|
|
" transition = TensorDict(\n",
|
|
" {\n",
|
|
" \"observations\": obs,\n",
|
|
" \"actions\": torch.as_tensor(actions, device=device, dtype=torch.float),\n",
|
|
" \"next\": {\n",
|
|
" \"observations\": true_next_obs,\n",
|
|
" \"rewards\": torch.as_tensor(rewards, device=device, dtype=torch.float),\n",
|
|
" \"truncations\": truncations.long(),\n",
|
|
" \"dones\": dones.long(),\n",
|
|
" },\n",
|
|
" },\n",
|
|
" batch_size=(envs.num_envs,),\n",
|
|
" device=device,\n",
|
|
" )\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" transition[\"critic_observations\"] = critic_obs\n",
|
|
" transition[\"next\"][\"critic_observations\"] = true_next_critic_obs\n",
|
|
"\n",
|
|
" obs = next_obs\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" critic_obs = next_critic_obs\n",
|
|
"\n",
|
|
" rb.extend(transition)\n",
|
|
"\n",
|
|
" batch_size = args.batch_size // args.num_envs\n",
|
|
" if global_step > args.learning_starts:\n",
|
|
" for i in range(args.num_updates):\n",
|
|
" data = rb.sample(batch_size)\n",
|
|
" data[\"observations\"] = normalize_obs(data[\"observations\"])\n",
|
|
" data[\"next\"][\"observations\"] = normalize_obs(data[\"next\"][\"observations\"])\n",
|
|
" if envs.asymmetric_obs:\n",
|
|
" data[\"critic_observations\"] = normalize_critic_obs(\n",
|
|
" data[\"critic_observations\"]\n",
|
|
" )\n",
|
|
" data[\"next\"][\"critic_observations\"] = normalize_critic_obs(\n",
|
|
" data[\"next\"][\"critic_observations\"]\n",
|
|
" )\n",
|
|
" logs_dict = update_main(data, logs_dict)\n",
|
|
" if args.num_updates > 1:\n",
|
|
" if i % args.policy_frequency == 1:\n",
|
|
" logs_dict = update_pol(data, logs_dict)\n",
|
|
" else:\n",
|
|
" if global_step % args.policy_frequency == 0:\n",
|
|
" logs_dict = update_pol(data, logs_dict)\n",
|
|
"\n",
|
|
" for param, target_param in zip(qnet.parameters(), qnet_target.parameters()):\n",
|
|
" target_param.data.copy_(\n",
|
|
" args.tau * param.data + (1 - args.tau) * target_param.data\n",
|
|
" )\n",
|
|
"\n",
|
|
" if global_step > 0 and global_step % 100 == 0:\n",
|
|
" with torch.no_grad():\n",
|
|
" logs = {\n",
|
|
" \"actor_loss\": logs_dict[\"actor_loss\"].mean(),\n",
|
|
" \"qf_loss\": logs_dict[\"qf_loss\"].mean(),\n",
|
|
" \"qf_max\": logs_dict[\"qf_max\"].mean(),\n",
|
|
" \"qf_min\": logs_dict[\"qf_min\"].mean(),\n",
|
|
" \"actor_grad_norm\": logs_dict[\"actor_grad_norm\"].mean(),\n",
|
|
" \"critic_grad_norm\": logs_dict[\"critic_grad_norm\"].mean(),\n",
|
|
" \"buffer_rewards\": logs_dict[\"buffer_rewards\"].mean(),\n",
|
|
" \"env_rewards\": rewards.mean(),\n",
|
|
" }\n",
|
|
"\n",
|
|
" if args.eval_interval > 0 and global_step % args.eval_interval == 0:\n",
|
|
" eval_avg_return, eval_avg_length = evaluate()\n",
|
|
" if env_type in [\"humanoid_bench\", \"isaaclab\"]:\n",
|
|
" # NOTE: Hacky way of evaluating performance, but just works\n",
|
|
" obs = envs.reset()\n",
|
|
" logs[\"eval_avg_return\"] = eval_avg_return\n",
|
|
" logs[\"eval_avg_length\"] = eval_avg_length\n",
|
|
"\n",
|
|
" if args.render_interval > 0 and global_step % args.render_interval == 0:\n",
|
|
" renders = render_with_rollout()\n",
|
|
" print_logs = {\n",
|
|
" k: v.item() if isinstance(v, torch.Tensor) else v\n",
|
|
" for k, v in logs.items()\n",
|
|
" }\n",
|
|
" for k, v in print_logs.items():\n",
|
|
" print(f\"{k}: {v:.4f}\")\n",
|
|
" update_video_display(renders, fps=30)\n",
|
|
" if use_wandb:\n",
|
|
" wandb.log(\n",
|
|
" {\n",
|
|
" \"render_video\": wandb.Video(\n",
|
|
" np.array(renders).transpose(\n",
|
|
" 0, 3, 1, 2\n",
|
|
" ), # Convert to (T, C, H, W) format\n",
|
|
" fps=30,\n",
|
|
" format=\"gif\",\n",
|
|
" )\n",
|
|
" },\n",
|
|
" step=global_step,\n",
|
|
" )\n",
|
|
" if use_wandb:\n",
|
|
" wandb.log(\n",
|
|
" {\n",
|
|
" \"frame\": global_step * args.num_envs,\n",
|
|
" **logs,\n",
|
|
" },\n",
|
|
" step=global_step,\n",
|
|
" )\n",
|
|
"\n",
|
|
" if (\n",
|
|
" args.save_interval > 0\n",
|
|
" and global_step > 0\n",
|
|
" and global_step % args.save_interval == 0\n",
|
|
" ):\n",
|
|
" save_params(\n",
|
|
" global_step,\n",
|
|
" actor,\n",
|
|
" qnet,\n",
|
|
" qnet_target,\n",
|
|
" obs_normalizer,\n",
|
|
" critic_obs_normalizer,\n",
|
|
" args,\n",
|
|
" f\"models/{run_name}_{global_step}.pt\",\n",
|
|
" )\n",
|
|
"\n",
|
|
" global_step += 1\n",
|
|
" pbar.update(1)\n",
|
|
"\n",
|
|
"save_params(\n",
|
|
" global_step,\n",
|
|
" actor,\n",
|
|
" qnet,\n",
|
|
" qnet_target,\n",
|
|
" obs_normalizer,\n",
|
|
" critic_obs_normalizer,\n",
|
|
" args,\n",
|
|
" f\"models/{run_name}_final.pt\",\n",
|
|
")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "fasttd3_hb",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.17"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|