# (removed copy-paste artifact: file-viewer metadata "736 lines / 26 KiB / Python")
from dataclasses import dataclass, replace
|
|
import functools
|
|
import os
|
|
import random
|
|
import sys
|
|
import copy
|
|
import time
|
|
|
|
import numpy as np
|
|
import tqdm
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
import wandb
|
|
|
|
from reppo_alg.torchrl.reppo import EmpiricalNormalization, hl_gauss
|
|
|
|
try:
|
|
# Required for avoiding IsaacGym import error
|
|
import isaacgym
|
|
except ImportError:
|
|
pass
|
|
|
|
import hydra
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torchinfo import summary
|
|
from tensordict import TensorDict
|
|
from torch.amp import GradScaler
|
|
from reppo_alg.torchrl.envs import make_envs
|
|
from reppo_alg.network_utils.torch_models import Actor, Critic
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Global platform configuration, applied at import time.
# ---------------------------------------------------------------------------
# Trade float32 matmul precision for speed (enables TF32/bf16 matmul kernels).
torch.set_float32_matmul_precision("medium")
# Let torch.compile trace through built-in nn.Module methods.
os.environ["TORCHDYNAMO_INLINE_INBUILT_NN_MODULES"] = "1"
# Avoid CPU oversubscription when many vectorized envs each spawn OpenMP threads.
os.environ["OMP_NUM_THREADS"] = "1"
if sys.platform != "darwin":
    # Headless (EGL) rendering for MuJoCo on Linux.
    os.environ["MUJOCO_GL"] = "egl"
else:
    # macOS has no EGL; fall back to GLFW.
    os.environ["MUJOCO_GL"] = "glfw"
# Keep JAX-backed env backends from preallocating all GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"
|
|
|
|
|
|
@dataclass(slots=True)
class TrainState:
    """Bundle of all mutable objects the training loop threads through its
    update functions.

    Module/optimizer state is mutated in place; the rolling observation
    buffers are swapped out via ``dataclasses.replace`` (see the collect
    function).
    """

    # Device all tensors and modules live on.
    device: torch.device
    # Latest actor observations, carried across rollout calls.
    obs: torch.Tensor
    # Latest critic observations (same tensor as `obs` unless the env
    # provides privileged/asymmetric observations).
    critic_obs: torch.Tensor
    # Policy network being optimized.
    actor: Actor
    # Snapshot of the actor from the previous iteration; serves as the KL
    # reference policy and is re-synced once per outer loop iteration.
    old_actor: Actor
    # Distributional Q-network (with auxiliary embedding head).
    critic: Critic
    # Running-statistics normalizer for actor observations.
    # NOTE(review): may also be an nn.Identity when normalize_env is off —
    # the annotation is narrower than actual usage.
    normalizer: EmpiricalNormalization
    # Running-statistics normalizer for critic observations (same caveat).
    critic_normalizer: EmpiricalNormalization
    actor_optimizer: optim.Optimizer
    critic_optimizer: optim.Optimizer
    # AMP gradient scaler (no-op unless fp16 scaling is enabled).
    scaler: GradScaler

    def compile(self):
        """Compile all networks/normalizers in place via nn.Module.compile().

        Optimizers are deliberately left uncompiled.
        """
        self.actor.compile()
        self.old_actor.compile()
        self.critic.compile()
        self.normalizer.compile()
        self.critic_normalizer.compile()
|
|
|
|
|
|
def get_autocast_context(cfg: DictConfig):
    """Return a zero-argument factory for ``torch.amp.autocast`` contexts.

    Mixed precision is only enabled when the config requests it AND a CUDA
    device is actually available; in every other case the returned context
    manager is a no-op (``enabled=False``).
    """
    cuda_ok = cfg.platform.cuda and torch.cuda.is_available()
    enabled = cfg.platform.amp_enabled and cuda_ok

    # Resolve the device type the autocast region should target. The MPS
    # branch is only reached when CUDA was requested but is unavailable.
    if cuda_ok:
        device_type = "cuda"
    elif cfg.platform.cuda and torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"

    # Only "bf16" selects a reduced-precision dtype; everything else falls
    # back to full float32.
    dtype = torch.bfloat16 if cfg.platform.amp_dtype == "bf16" else torch.float32

    return functools.partial(
        torch.amp.autocast,
        device_type=device_type,
        dtype=dtype,
        enabled=enabled,
    )
|
|
|
|
|
|
def make_collect_fn(cfg: DictConfig, env):
    """Build the rollout-collection closure for `env`.

    The returned function steps the vectorized env for
    ``cfg.hyperparameters.num_steps`` steps, records normalized transitions
    (with entropy-regularized rewards and next-state value/embedding targets
    for the critic), and returns the updated TrainState, a stacked
    TensorDict of shape (num_steps, num_envs), and the per-step info dicts.
    """
    autocast = get_autocast_context(cfg)
    asymmetric_obs = env.asymmetric_obs

    def collect_fn(
        train_state: TrainState,
    ) -> tuple[TrainState, TensorDict, list[dict]]:
        transitions = []
        info_list = []
        obs = train_state.obs
        critic_obs = train_state.critic_obs

        for _ in range(cfg.hyperparameters.num_steps):
            with autocast():
                # Normalizers are in train mode here, so these calls also
                # update the running statistics.
                norm_obs = train_state.normalizer(obs)
                norm_critic_obs = train_state.critic_normalizer(critic_obs)
                with torch.inference_mode():
                    pi, _, _, _ = train_state.actor(norm_obs)
                    actions = pi.sample()

            next_obs, rewards, dones, truncations, infos = env.step(actions)

            if asymmetric_obs:
                # Privileged critic observations are delivered via infos.
                next_critic_obs = infos["observations"]["critic"]
            else:
                next_critic_obs = next_obs

            with (
                torch.inference_mode(),
                autocast(),
            ):
                # For envs that auto-reset (partial_reset) and expose the
                # pre-reset observation, bootstrap from the true final
                # observation instead of the post-reset one.
                if (
                    cfg.env.get("has_final_obs", False)
                    and cfg.env.get("partial_reset", False)
                    and "final_observation" in infos
                ):
                    _next_obs = infos["final_observation"]
                    _next_critic_obs = _next_obs
                else:
                    _next_obs = next_obs
                    _next_critic_obs = next_critic_obs
                norm_next_obs = train_state.normalizer(_next_obs)
                next_pi, _, temperature, _ = train_state.actor(norm_next_obs)
                next_actions = next_pi.sample()
                # Clip away from ±1 so tanh-squashed log-probs stay finite.
                next_log_probs = next_pi.log_prob(
                    next_actions.clip(-1 + 1e-6, 1 - 1e-6)
                ).sum(-1)
                norm_next_critic_obs = train_state.critic_normalizer(_next_critic_obs)
                next_value, _, _, next_embedding = train_state.critic(
                    norm_next_critic_obs, next_actions
                )
                # Entropy-regularized (soft) reward: subtract the discounted
                # next-step log-prob scaled by the current temperature.
                rewards = (
                    rewards - cfg.hyperparameters.gamma * next_log_probs * temperature
                )

            transitions.append(
                TensorDict(
                    {
                        "observations": norm_obs,
                        "critic_observations": norm_critic_obs,
                        "actions": actions,
                        # NOTE(review): clips to ±0.999 here but ±(1 - 1e-6)
                        # everywhere else — confirm the asymmetry is intended.
                        "log_probs": pi.log_prob(actions.clip(-0.999, 0.999)).sum(-1),
                        "rewards": rewards.unsqueeze(-1),
                        "next_embeddings": next_embedding,
                        "next_values": next_value.unsqueeze(-1),
                        "dones": dones.unsqueeze(-1).float(),
                        "truncations": truncations.unsqueeze(-1).float(),
                    },
                    batch_size=(env.num_envs,),
                )
            )
            info_list.append(infos)
            obs = next_obs
            critic_obs = next_critic_obs

        # Carry the final observations forward for the next rollout.
        train_state = replace(train_state, obs=obs, critic_obs=critic_obs)
        return (
            train_state,
            torch.stack(transitions, dim=0),
            info_list,
        )

    return collect_fn
|
|
|
|
|
|
def make_postprocess_fn(cfg: DictConfig, env):
    """Build a function that attaches lambda-weighted value targets ("gve")
    to a rollout and flattens it into a single training batch."""

    # Data-dependent Python loop over time steps; excluded from
    # torch.compile tracing.
    @torch.compiler.disable()
    def compute_gve(rewards, dones, truncated, next_values, device: torch.device):
        """Backward recursion over the rollout producing one target per step.

        At each step the target mixes the running lambda-return with the
        one-step bootstrap `next_values[t]`; `device` is currently unused.
        """
        gves = []
        last_gve = 0
        # Treat the last rollout step as truncated so the recursion starts
        # from a pure bootstrap. NOTE(review): this mutates the caller's
        # truncations tensor in place — confirm intended.
        truncated[-1] = 1.0
        for t in reversed(range(cfg.hyperparameters.num_steps)):
            lambda_sum = (
                cfg.hyperparameters.lmbda * last_gve
                + (1.0 - cfg.hyperparameters.lmbda) * next_values[t]
            )
            # On truncation bootstrap purely from the critic; on terminal
            # `dones` the future term is zeroed out.
            delta = cfg.hyperparameters.gamma * torch.where(
                truncated[t].bool(), next_values[t], (1.0 - dones[t]) * lambda_sum
            )
            last_gve = rewards[t] + delta
            gves.insert(0, last_gve)
        return gves

    def postprocess(train_state: TrainState, transition: TensorDict):
        """Return a float, detached TensorDict of shape
        (num_steps * num_envs,) with the "gve" targets attached."""
        gve = compute_gve(
            rewards=transition["rewards"],
            dones=transition["dones"],
            truncated=transition["truncations"],
            next_values=transition["next_values"],
            device=train_state.device,
        )

        # Flatten all time and environment dimensions into a single batch dimension
        data = TensorDict(
            {
                "observations": transition["observations"],
                "critic_observations": transition["critic_observations"],
                "actions": transition["actions"],
                "rewards": transition["rewards"],
                "next_embeddings": transition["next_embeddings"],
                "next_values": transition["next_values"],
                "dones": transition["dones"],
                "truncations": transition["truncations"],
                "gve": torch.stack(gve),
            },
            batch_size=(
                cfg.hyperparameters.num_steps,
                cfg.hyperparameters.num_envs,
            ),
            device=train_state.device,
        )
        return data.float().flatten(0, 1).detach()

    return postprocess
|
|
|
|
|
|
def make_critic_update_fn(cfg: DictConfig, train_state: TrainState):
    """Build the critic minibatch update step.

    The closure captures `train_state` so the modules/optimizer appear as
    constants to torch.compile.
    """
    autocast = get_autocast_context(cfg)

    def update(data: TensorDict):
        """One gradient step on the distributional critic plus the auxiliary
        embedding-prediction head; returns detached scalars for logging."""
        qnet = train_state.critic
        q_optimizer = train_state.critic_optimizer

        with autocast():
            critic_observations = data["critic_observations"]
            actions = data["actions"]
            targets = data["gve"]
            target_embeddings = data["next_embeddings"]
            truncations = data["truncations"].squeeze(-1)
            # With partial resets, targets at truncation boundaries are
            # already bootstrapped correctly, so nothing is masked out;
            # otherwise truncated transitions are excluded from both losses.
            # NOTE(review): one branch yields a bool mask, the other a float
            # mask — both broadcast in the products below, but confirm.
            if cfg.env.get("partial_reset", False):
                truncation_mask = torch.ones_like(
                    truncations, dtype=torch.bool, device=train_state.device
                )
            else:
                truncation_mask = 1.0 - truncations
            # Project the scalar targets onto the categorical support
            # (HL-Gauss style target distribution).
            qf_target_dist = hl_gauss(
                targets,
                cfg.hyperparameters.vmin,
                cfg.hyperparameters.vmax,
                cfg.hyperparameters.num_bins,
            )

            _, qf1, embedding, _ = qnet(critic_observations, actions)
            # Cross-entropy between target distribution and predicted logits.
            qf_loss = -(
                truncation_mask
                * torch.sum(qf_target_dist * F.log_softmax(qf1, dim=-1), dim=-1)
            ).mean()
            # Self-predictive auxiliary loss: match the next state-action
            # embedding recorded during collection.
            embedding_loss = (
                truncation_mask
                * F.mse_loss(
                    embedding,
                    target_embeddings,
                    reduction="none",
                ).mean(dim=-1)
            ).mean()

            qf_loss = qf_loss + cfg.hyperparameters.aux_loss_mult * embedding_loss

        q_optimizer.zero_grad(set_to_none=True)
        # Scaled backward, then unscale so clipping sees true magnitudes.
        train_state.scaler.scale(qf_loss).backward()
        train_state.scaler.unscale_(q_optimizer)

        critic_grad_norm = torch.nn.utils.clip_grad_norm_(
            qnet.parameters(), max_norm=cfg.hyperparameters.max_grad_norm
        )
        train_state.scaler.step(q_optimizer)
        train_state.scaler.update()
        logs_dict = {
            "critic_grad_norm": critic_grad_norm.detach(),
            "qf_loss": qf_loss.detach(),
            "qf_max": targets.max().detach(),
            "qf_min": targets.min().detach(),
            "qf_mean": targets.mean().detach(),
            "embedding_loss": embedding_loss.detach(),
        }
        return logs_dict

    return update
|
|
|
|
|
|
def make_actor_update_fn(cfg: DictConfig, train_state: TrainState):
    """Build the actor minibatch update step.

    One backward pass jointly updates the policy, the entropy temperature,
    and the KL Lagrange multiplier (the latter two via selectively-detached
    loss terms).
    """
    autocast = get_autocast_context(cfg)

    def update(data: TensorDict):
        """One gradient step on the actor; returns detached logging scalars."""
        actor = train_state.actor
        old_actor = train_state.old_actor
        qnet = train_state.critic
        actor_optimizer = train_state.actor_optimizer
        scaler = train_state.scaler
        critic_obs = data["critic_observations"]
        with autocast():
            pi, _, temperature, beta = actor(data["observations"])
            # rsample() keeps the pathwise gradient through the critic.
            actions = pi.rsample()
            # Clip away from ±1 so tanh-squashed log-probs stay finite.
            log_probs = pi.log_prob(actions.clip(-1 + 1e-6, 1 - 1e-6)).sum(-1)
            entropy = -log_probs
            qf, _, _, _ = qnet(critic_obs, actions)
            # SAC-style objective; temperature detached here since it is
            # trained by its own loss below.
            actor_loss = -qf + temperature.detach() * log_probs

            # compute KL
            # Monte-Carlo KL(old || new) with 16 samples from the old policy.
            old_pi, _, _, _ = old_actor(data["observations"])
            old_pi_actions = old_pi.sample((16,)).clip(-1 + 1e-6, 1 - 1e-6)
            old_log_probs = old_pi.log_prob(old_pi_actions).sum(-1).mean(0)
            new_pi_log_probs = pi.log_prob(old_pi_actions).sum(-1).mean(0)
            kl = old_log_probs - new_pi_log_probs

            if cfg.hyperparameters.actor_kl_clip_mode == "clipped":
                # Outside the trust region, optimize the KL penalty instead
                # of the Q-objective.
                actor_loss = torch.where(
                    kl < cfg.hyperparameters.kl_bound,
                    actor_loss,
                    kl * beta.detach(),
                ).mean()
            elif cfg.hyperparameters.actor_kl_clip_mode == "full":
                # Always add the KL penalty.
                actor_loss = actor_loss + kl * beta.detach()
            elif cfg.hyperparameters.actor_kl_clip_mode == "value":
                # No KL shaping of the actor loss (beta still adapts below).
                actor_loss = actor_loss
            else:
                raise ValueError(
                    f"Unknown actor kl clip mode: {cfg.hyperparameters.actor_kl_clip_mode}"
                )

            # temperature updates
            target_entropy = (
                actions.shape[-1] * cfg.hyperparameters.ent_target_mult
            )  # -0.5 * np.prod(envs.action_space.shape)
            # Gradient flows only into `temperature` (entropy term detached).
            entropy_loss = (target_entropy + entropy).detach().mean() * temperature

            # Dual ascent on the KL constraint: gradient flows only into beta.
            lagrangian_loss = (
                -beta * (kl - cfg.hyperparameters.kl_bound).mean().detach()
            )

            actor_loss = (actor_loss + entropy_loss + lagrangian_loss).mean()

        actor_optimizer.zero_grad(set_to_none=True)
        scaler.scale(actor_loss).backward()
        # Unscale before clipping so the threshold applies to true grads.
        scaler.unscale_(actor_optimizer)
        actor_grad_norm = torch.nn.utils.clip_grad_norm_(
            actor.parameters(), max_norm=cfg.hyperparameters.max_grad_norm
        )
        scaler.step(actor_optimizer)
        scaler.update()
        logs_dict = {
            "actor_grad_norm": actor_grad_norm.detach(),
            "actor_loss": actor_loss.detach(),
            "kl": kl.detach(),
            "entropy": entropy.detach(),
            "temperature": temperature.detach(),
            "lagrangian": beta.detach(),
            "entropy_loss": entropy_loss.detach(),
            "lagrangian_loss": lagrangian_loss.detach(),
        }
        return logs_dict

    return update
|
|
|
|
|
|
def make_evaluate_fn(cfg: DictConfig, eval_envs):
    """Build an evaluation rollout function over `eval_envs`."""
    autocast = get_autocast_context(cfg)

    @torch.inference_mode()
    def evaluate(
        train_state: TrainState, stochastic_eval: bool = False
    ) -> tuple[float, float, dict]:
        """Run one episode per eval env (deterministic actions unless
        `stochastic_eval`) and return (mean return, mean length, info dict).

        Note: the original annotation claimed a 2-tuple; the function
        actually returns three values.
        """
        # Freeze normalizer statistics during evaluation.
        train_state.normalizer.eval()
        num_eval_envs = eval_envs.num_envs
        episode_returns = torch.zeros(num_eval_envs, device=train_state.device)
        episode_lengths = torch.zeros(num_eval_envs, device=train_state.device)
        # Envs whose first episode has finished stop accumulating stats.
        done_masks = torch.zeros(
            num_eval_envs, dtype=torch.bool, device=train_state.device
        )

        if cfg.env.type == "isaaclab" or cfg.env.asymmetric_observation:
            obs, _ = eval_envs.reset(random_start_init=False)
        else:
            obs = eval_envs.reset()

        # Run for a fixed number of steps
        for i in range(eval_envs.max_episode_steps):
            with autocast():
                obs = train_state.normalizer(obs)
                action_dist, det_actions, _, _ = train_state.actor(obs)
                if stochastic_eval:
                    actions = action_dist.sample()
                else:
                    actions = det_actions

            next_obs, rewards, dones, _, infos = eval_envs.step(actions)

            # Accumulate only for envs still on their first episode.
            episode_returns = torch.where(
                ~done_masks, episode_returns + rewards, episode_returns
            )
            episode_lengths = torch.where(
                ~done_masks, episode_lengths + 1, episode_lengths
            )
            done_masks = torch.logical_or(done_masks, dones)
            if done_masks.all():
                break
            obs = next_obs

        # Resume updating normalizer statistics for training.
        train_state.normalizer.train()

        if cfg.env.type == "maniskill":
            # combine log_infos
            # Uses the `infos` from the last executed step of the loop above.
            info = {
                "info_return": infos["log_info"]["return"].mean(),
                "episode_len": infos["log_info"]["episode_len"].float().mean(),
                "success": infos["log_info"]["success"].float().mean(),
                "return": episode_returns.mean().item(),
            }
        else:
            info = {}

        return episode_returns.mean().item(), episode_lengths.mean().item(), info

    return evaluate
|
|
|
|
|
|
def configure_platform(cfg: DictConfig) -> DictConfig:
    """Resolve platform flags against the hardware actually present.

    Downgrades ``amp_enabled`` when CUDA is unavailable and records the
    effective autocast device in ``cfg.platform.amp_device``. Mutates and
    returns the same config object.
    """
    cuda_ok = cfg.platform.cuda and torch.cuda.is_available()
    cfg.platform.amp_enabled = cfg.platform.amp_enabled and cuda_ok

    # The MPS branch only triggers when CUDA was requested but is missing.
    if cuda_ok:
        cfg.platform.amp_device = "cuda"
    elif cfg.platform.cuda and torch.backends.mps.is_available():
        cfg.platform.amp_device = "mps"
    else:
        cfg.platform.amp_device = "cpu"

    return cfg
|
|
|
|
|
|
@hydra.main(
    version_base=None,
    config_path="../../config",
    config_name="reppo",
)
def main(cfg):
    """Entry point: build envs, networks, and optimizers, then run the
    collect -> postprocess -> (critic, actor) update loop with periodic
    evaluation and wandb logging."""
    cfg = configure_platform(cfg)
    run_name = f"{cfg.env.name}_torch_{cfg.seed}"

    # NOTE(review): `amp_dtype` is a string elsewhere (compared against
    # "bf16" in get_autocast_context), so `== torch.float16` is always
    # False here and the scaler is effectively always disabled — confirm
    # whether "fp16" was intended.
    scaler = GradScaler(
        enabled=cfg.platform.amp_enabled and cfg.platform.amp_dtype == torch.float16
    )

    num_batches = cfg.hyperparameters.num_mini_batches
    # One rollout's transitions split evenly across the minibatches.
    batch_size = (
        cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps // num_batches
    )

    wandb.init(
        project=cfg.wandb.project,
        name=run_name,
        config=OmegaConf.to_container(cfg),
        save_code=True,
    )

    # Seed every RNG source for reproducibility.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.platform.torch_deterministic

    # Resolve the torch device: CUDA preferred, then MPS, else fail loudly.
    if not cfg.platform.cuda:
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{cfg.platform.device_rank}")
        elif torch.backends.mps.is_available():
            device = torch.device(f"mps:{cfg.platform.device_rank}")
        else:
            raise ValueError("No GPU available")
    print(f"Using device: {device}")

    envs, eval_envs = make_envs(cfg=cfg, device=device, seed=cfg.seed)

    # Observation/action sizes; `num_obs` may be an int or a shape tuple.
    n_act = envs.num_actions
    n_obs = envs.num_obs if isinstance(envs.num_obs, int) else envs.num_obs[0]
    if envs.asymmetric_obs:
        # The critic receives privileged observations when available.
        n_critic_obs = (
            envs.num_privileged_obs
            if isinstance(envs.num_privileged_obs, int)
            else envs.num_privileged_obs[0]
        )
    else:
        n_critic_obs = n_obs

    if cfg.hyperparameters.normalize_env:
        obs_normalizer = EmpiricalNormalization(shape=n_obs, device=device)
        critic_obs_normalizer = EmpiricalNormalization(
            shape=n_critic_obs, device=device
        )
    else:
        # Identity modules keep all call sites uniform when not normalizing.
        obs_normalizer = nn.Identity()
        critic_obs_normalizer = nn.Identity()

    actor = Actor(
        n_obs=n_obs,
        n_act=n_act,
        ent_start=cfg.hyperparameters.ent_start,
        kl_start=cfg.hyperparameters.kl_start,
        hidden_dim=cfg.hyperparameters.actor_hidden_dim,
        use_norm=cfg.hyperparameters.use_actor_norm,
        layers=cfg.hyperparameters.num_actor_layers,
        min_std=cfg.hyperparameters.actor_min_std,
        device=device,
    )
    # KL reference policy; hard-synced to the actor after every iteration.
    old_actor = copy.deepcopy(actor)
    qnet = Critic(
        n_obs=n_critic_obs,
        n_act=n_act,
        num_atoms=cfg.hyperparameters.num_bins,
        vmin=cfg.hyperparameters.vmin,
        vmax=cfg.hyperparameters.vmax,
        hidden_dim=cfg.hyperparameters.critic_hidden_dim,
        use_norm=cfg.hyperparameters.use_critic_norm,
        use_encoder_norm=False,
        encoder_layers=cfg.hyperparameters.num_critic_encoder_layers,
        head_layers=cfg.hyperparameters.num_critic_head_layers,
        pred_layers=cfg.hyperparameters.num_critic_pred_layers,
        device=device,
    )

    # Tensor-valued lr — presumably to keep torch.compile graphs stable
    # across lr changes; TODO confirm.
    q_optimizer = optim.AdamW(
        list(qnet.parameters()),
        lr=torch.tensor(cfg.hyperparameters.lr, device=device),
    )
    actor_optimizer = optim.AdamW(
        list(actor.parameters()),
        lr=torch.tensor(cfg.hyperparameters.lr, device=device),
    )

    if envs.asymmetric_obs:
        obs, critic_obs = envs.reset_with_critic_obs()
        critic_obs = torch.as_tensor(critic_obs, device=device, dtype=torch.float)
    else:
        obs = envs.reset()
        critic_obs = obs

    train_state = TrainState(
        obs=obs,
        critic_obs=critic_obs,
        actor=actor,
        old_actor=old_actor,
        critic=qnet,
        normalizer=obs_normalizer,
        critic_normalizer=critic_obs_normalizer,
        actor_optimizer=actor_optimizer,
        critic_optimizer=q_optimizer,
        device=device,
        scaler=scaler,
    )

    # Print network architectures for the run log.
    print(
        summary(
            train_state.critic,
            input_data=(critic_obs[:1], torch.zeros((1, n_act), device=device)),
            depth=10,
        )
    )
    print(summary(train_state.actor, input_data=(obs[:1],), depth=10))
    # create functions
    collect_fn = make_collect_fn(cfg, envs)
    postprocess_fn = make_postprocess_fn(cfg, envs)
    update_critic = make_critic_update_fn(cfg, train_state)
    update_actor = make_actor_update_fn(cfg, train_state)
    evaluate = make_evaluate_fn(cfg, eval_envs)

    if cfg.platform.compile:
        # Collection is left uncompiled; only the update/postprocess steps
        # (and the modules themselves) are compiled.
        mode = "max-autotune-no-cudagraphs"
        update_critic = torch.compile(update_critic, mode=mode)
        update_actor = torch.compile(update_actor, mode=mode)
        postprocess_fn = torch.compile(postprocess_fn, mode=mode)
        train_state.compile()

    # TODO: Support checkpoint loading
    # if cfg.checkpoint_path:
    #     # Load checkpoint if specified
    #     torch_checkpoint = torch.load(
    #         f"{cfg.checkpoint_path}", map_location=device, weights_only=False
    #     )
    #     actor.load_state_dict(torch_checkpoint["actor_state_dict"])
    #     obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])
    #     critic_obs_normalizer.load_state_dict(
    #         torch_checkpoint["critic_obs_normalizer_state"]
    #     )
    #     qnet.load_state_dict(torch_checkpoint["qnet_state_dict"])
    #     qnet_target.load_state_dict(torch_checkpoint["qnet_target_state_dict"])
    #     global_step = torch_checkpoint["global_step"]
    # else:
    global_step = 0
    # Number of outer iterations; each consumes num_envs * num_steps frames.
    total_env_steps = (
        cfg.hyperparameters.total_time_steps
        // (cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps)
        + 1
    )

    pbar = tqdm.tqdm(total=cfg.hyperparameters.total_time_steps, initial=global_step)
    start_time = None
    # `desc` is a placeholder; it is never updated below.
    desc = ""

    eval_interval = total_env_steps // cfg.hyperparameters.num_eval
    stochastic_eval = cfg.env.get("stochastic_eval", False)

    while global_step < total_env_steps:
        # Start the throughput clock only after a burn-in period (warmup,
        # compilation) so speed numbers are representative.
        if start_time is None and global_step >= cfg.measure_burnin:
            start_time = time.time()
            measure_burnin = global_step

        train_state, transition, infos = collect_fn(train_state)
        data = postprocess_fn(train_state, transition)

        for _ in range(cfg.hyperparameters.num_epochs):
            # Fresh shuffle of the flattened batch each epoch.
            indices = torch.randperm(
                cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps,
                device=device,
            )
            data = data[indices].contiguous()
            for j in range(num_batches):
                mini_batch = data[j * batch_size : (j + 1) * batch_size]
                critic_logs_dict = update_critic(mini_batch)
                actor_logs_dict = update_actor(mini_batch)
                # Only the last minibatch's logs survive to the logging block.
                logs_dict = {
                    **critic_logs_dict,
                    **actor_logs_dict,
                }

        # Hard-sync the KL reference policy to the freshly updated actor.
        for param, target_param in zip(actor.parameters(), old_actor.parameters()):
            target_param.data.copy_(param.data)
        if start_time is not None:
            # @TODO: shouldn't that be env_steps per second?
            speed = (
                cfg.hyperparameters.num_envs
                * cfg.hyperparameters.num_steps
                * (global_step - measure_burnin)
                / (time.time() - start_time)
            )
            pbar.set_description(f"{speed: 4.4f} sps, " + desc)
            with torch.inference_mode():
                logs = {
                    "critic/qf_loss": logs_dict["qf_loss"].mean(),
                    "critic/qf_max": logs_dict["qf_max"].mean(),
                    "critic/qf_min": logs_dict["qf_min"].mean(),
                    "critic/qf_mean": logs_dict["qf_mean"].mean(),
                    "critic/embedding_loss": logs_dict["embedding_loss"].mean(),
                    "critic/critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
                    "actor/actor_loss": logs_dict["actor_loss"].mean(),
                    "actor/actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
                    "actor/kl": logs_dict["kl"].mean(),
                    "actor/entropy": logs_dict["entropy"].mean(),
                    "actor/temperature": logs_dict["temperature"].mean(),
                    "actor/lagrangian": logs_dict["lagrangian"].mean(),
                    "actor/entropy_loss": logs_dict["entropy_loss"].mean(),
                    "actor/lagrangian_loss": logs_dict["lagrangian_loss"].mean(),
                    "train/rewards_batch": data["rewards"].mean(),
                }

                if cfg.env.type == "maniskill":
                    logs.update(
                        {
                            "train/return": torch.stack(
                                [info["log_info"]["return"] for info in infos]
                            ).mean(),
                            "train/episode_len": torch.stack(
                                [info["log_info"]["episode_len"] for info in infos]
                            )
                            .float()
                            .mean(),
                            "train/success": torch.stack(
                                [info["log_info"]["success"] for info in infos]
                            )
                            .float()
                            .mean(),
                        }
                    )

            if eval_interval > 0 and global_step % eval_interval == 0:
                print(f"Evaluating at global step {global_step}")
                if stochastic_eval:
                    # Run both a stochastic and a deterministic eval; the
                    # reported return/length come from the deterministic one.
                    eval_avg_return, eval_avg_length, stoch_eval_info = evaluate(
                        train_state, stochastic_eval=stochastic_eval
                    )
                    eval_avg_return, eval_avg_length, eval_info = evaluate(
                        train_state
                    )
                    eval_info = {
                        **eval_info,
                        **{f"stoch/{k}": v for k, v in stoch_eval_info.items()},
                    }
                else:
                    eval_avg_return, eval_avg_length, eval_info = evaluate(
                        train_state
                    )
                if cfg.env.type in [
                    "humanoid_bench",
                    "isaaclab",
                    "mtbench",
                ]:
                    # NOTE: Hacky way of evaluating performance, but just works
                    # NOTE(review): the reset obs is not written back into
                    # train_state.obs — confirm the next rollout's stale
                    # start state is acceptable.
                    obs, _ = envs.reset()
                logs["eval/avg_return"] = eval_avg_return
                logs["eval/avg_length"] = eval_avg_length
                for key, value in eval_info.items():
                    if isinstance(value, torch.Tensor):
                        logs[f"eval/{key}"] = value.mean().item()
                    elif isinstance(value, np.ndarray):
                        logs[f"eval/{key}"] = value.mean()
                    else:
                        logs[f"eval/{key}"] = value
                print(
                    f"Eval return: {eval_avg_return:.2f}, length: {eval_avg_length:.2f}, env steps: {global_step * cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps} success rate: {eval_info.get('success', 0.0):.2f}"
                )
            # Log against environment frames, not iterations.
            wandb.log(
                {
                    "speed": speed,
                    "frame": global_step
                    * cfg.hyperparameters.num_envs
                    * cfg.hyperparameters.num_steps,
                    **logs,
                },
                step=global_step
                * cfg.hyperparameters.num_envs
                * cfg.hyperparameters.num_steps,
            )

        global_step += 1
        pbar.update(n=cfg.hyperparameters.num_envs * cfg.hyperparameters.num_steps)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hydra parses CLI overrides and injects the composed config.
    main()
|