import os
import sys

# Backend configuration: these environment variables must be set before
# torch, mujoco, and jax are imported below, since they are read at import
# or initialization time.
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
    os.environ["MUJOCO_GL"] = "egl"
else:
    os.environ["MUJOCO_GL"] = "glfw"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"

import random
import time

import tqdm
import wandb
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler

from tensordict import TensorDict, from_module

from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params
from hyperparams import get_args
from fast_td3 import Actor, Critic

torch.set_float32_matmul_precision("high")

# jax.numpy is only needed for MuJoCo Playground rendering, so a missing JAX
# install is tolerated for the other environment backends.
try:
    import jax.numpy as jnp
except ImportError:
    pass

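# Training entry point for FastTD3: TD3 on massively parallel simulation
# environments, with distributional (categorical) critics, optional AMP, and
# optional torch.compile of the per-step hot paths.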
def main():
    args = get_args()
    print(args)
    run_name = f"{args.env_name}__{args.exp_name}__{args.seed}"

    amp_enabled = args.amp and args.cuda and torch.cuda.is_available()
    amp_device_type = (
        "cuda"
        if args.cuda and torch.cuda.is_available()
        else "mps" if args.cuda and torch.backends.mps.is_available() else "cpu"
    )
    amp_dtype = torch.bfloat16 if args.amp_dtype == "bf16" else torch.float16

    # Gradient scaling is only needed for float16 autocast; bfloat16 has the
    # same exponent range as float32, so the scaler is disabled for bf16.
    scaler = GradScaler(enabled=amp_enabled and amp_dtype == torch.float16)

    if args.use_wandb:
        wandb.init(
            project=args.project,
            name=run_name,
            config=vars(args),
            save_code=True,
        )

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    if not args.cuda:
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{args.device_rank}")
        elif torch.backends.mps.is_available():
            device = torch.device(f"mps:{args.device_rank}")
        else:
            raise ValueError("No GPU available")
    print(f"Using device: {device}")

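    # Environment backend is selected from the env-name prefix: "h1-"/"h1hand-"
    # maps to HumanoidBench, "Isaac-" to IsaacLab, and anything else to
    # MuJoCo Playground.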
    if args.env_name.startswith("h1hand-") or args.env_name.startswith("h1-"):
        from environments.humanoid_bench_env import HumanoidBenchEnv

        env_type = "humanoid_bench"
        envs = HumanoidBenchEnv(args.env_name, args.num_envs, device=device)
        eval_envs = envs
        render_env = HumanoidBenchEnv(
            args.env_name, 1, render_mode="rgb_array", device=device
        )
    elif args.env_name.startswith("Isaac-"):
        from environments.isaaclab_env import IsaacLabEnv

        env_type = "isaaclab"
        envs = IsaacLabEnv(
            args.env_name,
            device.type,
            args.num_envs,
            args.seed,
            action_bounds=args.action_bounds,
        )
        eval_envs = envs
        render_env = envs
    else:
        from environments.mujoco_playground_env import make_env

        # TODO: Check if re-using same envs for eval could reduce memory usage
        env_type = "mujoco_playground"
        envs, eval_envs, render_env = make_env(
            args.env_name,
            args.seed,
            args.num_envs,
            args.num_eval_envs,
            args.device_rank,
            use_tuned_reward=args.use_tuned_reward,
            use_domain_randomization=args.use_domain_randomization,
            use_push_randomization=args.use_push_randomization,
        )

    n_act = envs.num_actions
    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
    if envs.asymmetric_obs:
        n_critic_obs = (
            envs.num_privileged_obs
            if isinstance(envs.num_privileged_obs, int)
            else envs.num_privileged_obs[0]
        )
    else:
        n_critic_obs = n_obs
    action_low, action_high = -1.0, 1.0

    if args.obs_normalization:
        obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
        critic_obs_normalizer = EmpiricalNormalization(
            shape=n_critic_obs, device=device
        )
    else:
        obs_normalizer = nn.Identity()
        critic_obs_normalizer = nn.Identity()

    actor = Actor(
        n_obs=n_obs,
        n_act=n_act,
        num_envs=args.num_envs,
        device=device,
        init_scale=args.init_scale,
        hidden_dim=args.actor_hidden_dim,
    )
    actor_detach = Actor(
        n_obs=n_obs,
        n_act=n_act,
        num_envs=args.num_envs,
        device=device,
        init_scale=args.init_scale,
        hidden_dim=args.actor_hidden_dim,
    )
    # Copy params to actor_detach without grad: the detached TensorDict view
    # shares storage with the actor's parameters, so `policy` always sees the
    # latest weights but never records autograd history during rollouts.
    from_module(actor).data.to_module(actor_detach)
    policy = actor_detach.explore

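    # The critics are distributional: each head predicts logits over num_atoms
    # support points in [v_min, v_max] (C51-style), and get_value() collapses a
    # categorical distribution back into a scalar Q-value.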
    qnet = Critic(
        n_obs=n_critic_obs,
        n_act=n_act,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        hidden_dim=args.critic_hidden_dim,
        device=device,
    )
    qnet_target = Critic(
        n_obs=n_critic_obs,
        n_act=n_act,
        num_atoms=args.num_atoms,
        v_min=args.v_min,
        v_max=args.v_max,
        hidden_dim=args.critic_hidden_dim,
        device=device,
    )
    qnet_target.load_state_dict(qnet.state_dict())

    q_optimizer = optim.AdamW(
        list(qnet.parameters()),
        lr=args.critic_learning_rate,
        weight_decay=args.weight_decay,
    )
    actor_optimizer = optim.AdamW(
        list(actor.parameters()),
        lr=args.actor_learning_rate,
        weight_decay=args.weight_decay,
    )

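    # A note, judging by the constructor arguments: the buffer keeps one row
    # per parallel env, and n_steps / gamma are handed over so it can assemble
    # n-step discounted targets at sampling time (see SimpleReplayBuffer).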
    rb = SimpleReplayBuffer(
        n_env=args.num_envs,
        buffer_size=args.buffer_size,
        n_obs=n_obs,
        n_act=n_act,
        n_critic_obs=n_critic_obs,
        asymmetric_obs=envs.asymmetric_obs,
        n_steps=args.num_steps,
        gamma=args.gamma,
        device=device,
    )

    policy_noise = args.policy_noise
    noise_clip = args.noise_clip

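    # Evaluation steps all eval envs in lockstep for at most one episode each;
    # done_masks freezes an env's return/length once its first `done` is seen,
    # so autoresets don't leak into the statistics.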
    def evaluate():
        obs_normalizer.eval()
        num_eval_envs = eval_envs.num_envs
        episode_returns = torch.zeros(num_eval_envs, device=device)
        episode_lengths = torch.zeros(num_eval_envs, device=device)
        done_masks = torch.zeros(num_eval_envs, dtype=torch.bool, device=device)

        if env_type == "isaaclab":
            obs = eval_envs.reset(random_start_init=False)
        else:
            obs = eval_envs.reset()

        # Run for a fixed number of steps
        for _ in range(eval_envs.max_episode_steps):
            with torch.no_grad(), autocast(
                device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
            ):
                obs = normalize_obs(obs)
                actions = actor(obs)

            next_obs, rewards, dones, _ = eval_envs.step(actions.float())
            episode_returns = torch.where(
                ~done_masks, episode_returns + rewards, episode_returns
            )
            episode_lengths = torch.where(
                ~done_masks, episode_lengths + 1, episode_lengths
            )
            done_masks = torch.logical_or(done_masks, dones)
            if done_masks.all():
                break
            obs = next_obs

        obs_normalizer.train()
        return episode_returns.mean().item(), episode_lengths.mean().item()

    def render_with_rollout():
        obs_normalizer.eval()

        # Quick rollout for rendering
        if env_type == "humanoid_bench":
            obs = render_env.reset()
            renders = [render_env.render()]
        elif env_type == "isaaclab":
            raise NotImplementedError(
                "We don't support rendering for IsaacLab environments"
            )
        else:
            obs = render_env.reset()
            render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
            renders = [render_env.state]
        for i in range(render_env.max_episode_steps):
            with torch.no_grad(), autocast(
                device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
            ):
                obs = normalize_obs(obs)
                actions = actor(obs)
            next_obs, _, done, _ = render_env.step(actions.float())
            if env_type == "mujoco_playground":
                render_env.state.info["command"] = jnp.array([[1.0, 0.0, 0.0]])
            if i % 2 == 0:
                if env_type == "humanoid_bench":
                    renders.append(render_env.render())
                else:
                    renders.append(render_env.state)
            if done.any():
                break
            obs = next_obs

        if env_type == "mujoco_playground":
            renders = render_env.render_trajectory(renders)

        obs_normalizer.train()
        return renders

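    # Critic update: twin critics trained with a cross-entropy loss against
    # the target network's projected categorical distribution. Target actions
    # are smoothed with clipped Gaussian noise (TD3), and `bootstrap` keeps
    # the Bellman target alive through timeouts but not true terminations.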
    def update_main(data, logs_dict):
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            observations = data["observations"]
            next_observations = data["next"]["observations"]
            if envs.asymmetric_obs:
                critic_observations = data["critic_observations"]
                next_critic_observations = data["next"]["critic_observations"]
            else:
                critic_observations = observations
                next_critic_observations = next_observations
            actions = data["actions"]
            rewards = data["next"]["rewards"]
            dones = data["next"]["dones"].bool()
            truncations = data["next"]["truncations"].bool()
            if args.disable_bootstrap:
                bootstrap = (~dones).float()
            else:
                # Bootstrap through timeouts: a truncated episode is not a
                # real terminal state, so its value target still bootstraps.
                bootstrap = (truncations | ~dones).float()

            clipped_noise = torch.randn_like(actions)
            clipped_noise = clipped_noise.mul(policy_noise).clamp(
                -noise_clip, noise_clip
            )

            next_state_actions = (actor(next_observations) + clipped_noise).clamp(
                action_low, action_high
            )

            with torch.no_grad():
                qf1_next_target_projected, qf2_next_target_projected = (
                    qnet_target.projection(
                        next_critic_observations,
                        next_state_actions,
                        rewards,
                        bootstrap,
                        args.gamma,
                    )
                )
                qf1_next_target_value = qnet_target.get_value(qf1_next_target_projected)
                qf2_next_target_value = qnet_target.get_value(qf2_next_target_projected)
                if args.use_cdq:
                    qf_next_target_dist = torch.where(
                        qf1_next_target_value.unsqueeze(1)
                        < qf2_next_target_value.unsqueeze(1),
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )
                    qf1_next_target_dist = qf2_next_target_dist = qf_next_target_dist
                else:
                    qf1_next_target_dist, qf2_next_target_dist = (
                        qf1_next_target_projected,
                        qf2_next_target_projected,
                    )

            # Cross-entropy between each critic's predicted logits and the
            # projected target distribution (the distributional analogue of
            # the squared TD error).
            qf1, qf2 = qnet(critic_observations, actions)
            qf1_loss = -torch.sum(
                qf1_next_target_dist * F.log_softmax(qf1, dim=1), dim=1
            ).mean()
            qf2_loss = -torch.sum(
                qf2_next_target_dist * F.log_softmax(qf2, dim=1), dim=1
            ).mean()
            qf_loss = qf1_loss + qf2_loss

        q_optimizer.zero_grad(set_to_none=True)
        scaler.scale(qf_loss).backward()
        # Unscale before clipping so the norm is measured on the true
        # (unscaled) gradients.
        scaler.unscale_(q_optimizer)

        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            qnet.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(q_optimizer)
        scaler.update()

        logs_dict["buffer_rewards"] = rewards.mean()
        logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
        logs_dict["qf_loss"] = qf_loss.detach()
        logs_dict["qf_max"] = qf1_next_target_value.max().detach()
        logs_dict["qf_min"] = qf1_next_target_value.min().detach()
        return logs_dict

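    # Actor update: ascend the critic's expected Q-value of the actor's own
    # action. With use_cdq the pessimistic minimum of the twin critics is
    # optimized; otherwise their average.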
    def update_pol(data, logs_dict):
        with autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            critic_observations = (
                data["critic_observations"]
                if envs.asymmetric_obs
                else data["observations"]
            )

            qf1, qf2 = qnet(critic_observations, actor(data["observations"]))
            qf1_value = qnet.get_value(F.softmax(qf1, dim=1))
            qf2_value = qnet.get_value(F.softmax(qf2, dim=1))
            if args.use_cdq:
                qf_value = torch.minimum(qf1_value, qf2_value)
            else:
                qf_value = (qf1_value + qf2_value) / 2.0
            actor_loss = -qf_value.mean()

        actor_optimizer.zero_grad(set_to_none=True)
        scaler.scale(actor_loss).backward()
        scaler.unscale_(actor_optimizer)
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
            actor.parameters(),
            max_norm=args.max_grad_norm if args.max_grad_norm > 0 else float("inf"),
        )
        scaler.step(actor_optimizer)
        scaler.update()
        logs_dict["actor_grad_norm"] = actor_grad_norm.detach()
        logs_dict["actor_loss"] = actor_loss.detach()
        return logs_dict

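    # Compiling the update functions, policy, and normalizers covers every
    # per-step hot path. Presumably this is also why the
    # TORCHDYNAMO_INLINE_INBUILT_NN_MODULES flag is set at the top of the
    # file, so dynamo traces through the built-in nn.Module calls.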
    if args.compile:
        mode = None
        update_main = torch.compile(update_main, mode=mode)
        update_pol = torch.compile(update_pol, mode=mode)
        policy = torch.compile(policy, mode=mode)
        normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
        normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
    else:
        normalize_obs = obs_normalizer.forward
        normalize_critic_obs = critic_obs_normalizer.forward

    if envs.asymmetric_obs:
        obs, critic_obs = envs.reset_with_critic_obs()
        critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
    else:
        obs = envs.reset()
    if args.checkpoint_path:
        # Load checkpoint if specified
        torch_checkpoint = torch.load(
            f"{args.checkpoint_path}", map_location=device, weights_only=False
        )
        actor.load_state_dict(torch_checkpoint["actor_state_dict"])
        obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
        critic_obs_normalizer.load_state_dict(
            torch_checkpoint["critic_obs_normalizer_state"]
        )
        qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
        qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
        global_step = torch_checkpoint["global_step"]
    else:
        global_step = 0

    dones = None
    pbar = tqdm.tqdm(total=args.total_timesteps, initial=global_step)
    start_time = None
    desc = ""

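    # Main loop: one vectorized environment step per iteration, then
    # num_updates gradient steps (with periodic actor updates) once
    # learning_starts steps have been collected.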
    while global_step < args.total_timesteps:
        logs_dict = TensorDict()
        if (
            start_time is None
            and global_step >= args.measure_burnin + args.learning_starts
        ):
            start_time = time.time()
            measure_burnin = global_step

        with torch.no_grad(), autocast(
            device_type=amp_device_type, dtype=amp_dtype, enabled=amp_enabled
        ):
            norm_obs = normalize_obs(obs)
            actions = policy(obs=norm_obs, dones=dones)

        next_obs, rewards, dones, infos = envs.step(actions.float())
        truncations = infos["time_outs"]

        if envs.asymmetric_obs:
            next_critic_obs = infos["observations"]["critic"]

        # Compute 'true' next_obs and next_critic_obs for saving: autoreset
        # envs already return the post-reset observation in `next_obs`, so the
        # actual final observation of a finished episode has to be recovered
        # from infos["observations"]["raw"].
        true_next_obs = torch.where(
            dones[:, None] > 0, infos["observations"]["raw"]["obs"], next_obs
        )
        if envs.asymmetric_obs:
            true_next_critic_obs = torch.where(
                dones[:, None] > 0,
                infos["observations"]["raw"]["critic_obs"],
                next_critic_obs,
            )
        transition = TensorDict(
            {
                "observations": obs,
                "actions": torch.as_tensor(actions, device=device, dtype=torch.float),
                "next": {
                    "observations": true_next_obs,
                    "rewards": torch.as_tensor(
                        rewards, device=device, dtype=torch.float
                    ),
                    "truncations": truncations.long(),
                    "dones": dones.long(),
                },
            },
            batch_size=(envs.num_envs,),
            device=device,
        )
        if envs.asymmetric_obs:
            transition["critic_observations"] = critic_obs
            transition["next"]["critic_observations"] = true_next_critic_obs

        obs = next_obs
        if envs.asymmetric_obs:
            critic_obs = next_critic_obs

        rb.extend(transition)

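        # Assumption from the per-env buffer layout: sampling batch_size //
        # num_envs indices per env row keeps the effective minibatch at
        # args.batch_size transitions.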
        batch_size = args.batch_size // args.num_envs
        if global_step > args.learning_starts:
            for i in range(args.num_updates):
                data = rb.sample(batch_size)
                data["observations"] = normalize_obs(data["observations"])
                data["next"]["observations"] = normalize_obs(
                    data["next"]["observations"]
                )
                if envs.asymmetric_obs:
                    data["critic_observations"] = normalize_critic_obs(
                        data["critic_observations"]
                    )
                    data["next"]["critic_observations"] = normalize_critic_obs(
                        data["next"]["critic_observations"]
                    )
                logs_dict = update_main(data, logs_dict)
                if args.num_updates > 1:
                    if i % args.policy_frequency == 1:
                        logs_dict = update_pol(data, logs_dict)
                else:
                    if global_step % args.policy_frequency == 0:
                        logs_dict = update_pol(data, logs_dict)

                # Soft (Polyak) target update:
                # target <- tau * param + (1 - tau) * target
                for param, target_param in zip(
                    qnet.parameters(), qnet_target.parameters()
                ):
                    target_param.data.copy_(
                        args.tau * param.data + (1 - args.tau) * target_param.data
                    )

        if global_step % 100 == 0 and start_time is not None:
            speed = (global_step - measure_burnin) / (time.time() - start_time)
            pbar.set_description(f"{speed: 4.4f} sps, " + desc)
            with torch.no_grad():
                logs = {
                    "actor_loss": logs_dict["actor_loss"].mean(),
                    "qf_loss": logs_dict["qf_loss"].mean(),
                    "qf_max": logs_dict["qf_max"].mean(),
                    "qf_min": logs_dict["qf_min"].mean(),
                    "actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
                    "critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
                    "buffer_rewards": logs_dict["buffer_rewards"].mean(),
                    "env_rewards": rewards.mean(),
                }

            if args.eval_interval > 0 and global_step % args.eval_interval == 0:
                print(f"Evaluating at global step {global_step}")
                eval_avg_return, eval_avg_length = evaluate()
                if env_type in ["humanoid_bench", "isaaclab"]:
                    # NOTE: Hacky way of evaluating performance, but just works
                    obs = envs.reset()
                logs["eval_avg_return"] = eval_avg_return
                logs["eval_avg_length"] = eval_avg_length

                if (
                    args.render_interval > 0
                    and global_step % args.render_interval == 0
                ):
                    renders = render_with_rollout()
                    if args.use_wandb:
                        wandb.log(
                            {
                                "render_video": wandb.Video(
                                    np.array(renders).transpose(
                                        0, 3, 1, 2
                                    ),  # Convert to (T, C, H, W) format
                                    fps=30,
                                    format="gif",
                                )
                            },
                            step=global_step,
                        )
            if args.use_wandb:
                wandb.log(
                    {
                        "speed": speed,
                        "frame": global_step * args.num_envs,
                        **logs,
                    },
                    step=global_step,
                )

        if (
            args.save_interval > 0
            and global_step > 0
            and global_step % args.save_interval == 0
        ):
            print(f"Saving model at global step {global_step}")
            save_params(
                global_step,
                actor,
                qnet,
                qnet_target,
                obs_normalizer,
                critic_obs_normalizer,
                args,
                f"models/{run_name}_{global_step}.pt",
            )

        global_step += 1
        pbar.update(1)

    save_params(
        global_step,
        actor,
        qnet,
        qnet_target,
        obs_normalizer,
        critic_obs_normalizer,
        args,
        f"models/{run_name}_final.pt",
    )


if __name__ == "__main__":
    main()