# FastTD3/fast_td3/fast_td3_utils.py

import os
from typing import Optional
import torch
import torch.nn as nn
import torch.distributed as dist
from tensordict import TensorDict
class SimpleReplayBuffer(nn.Module):
def __init__(
self,
n_env: int,
buffer_size: int,
n_obs: int,
n_act: int,
n_critic_obs: int,
asymmetric_obs: bool = False,
playground_mode: bool = False,
n_steps: int = 1,
gamma: float = 0.99,
device=None,
):
"""
A simple replay buffer that stores transitions in a circular buffer.
Supports n-step returns and asymmetric observations.
When playground_mode=True, critic_observations are treated as a concatenation of
regular observations and privileged observations, and only the privileged part is stored
to save memory.
TODO (Younggyo): Refactor to split this into SimpleReplayBuffer and NStepReplayBuffer
"""
super().__init__()
self.n_env = n_env
self.buffer_size = buffer_size
self.n_obs = n_obs
self.n_act = n_act
self.n_critic_obs = n_critic_obs
self.asymmetric_obs = asymmetric_obs
self.playground_mode = playground_mode and asymmetric_obs
self.gamma = gamma
self.n_steps = n_steps
self.device = device
self.observations = torch.zeros(
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
)
self.actions = torch.zeros(
(n_env, buffer_size, n_act), device=device, dtype=torch.float
)
self.rewards = torch.zeros(
(n_env, buffer_size), device=device, dtype=torch.float
)
self.dones = torch.zeros((n_env, buffer_size), device=device, dtype=torch.long)
self.truncations = torch.zeros(
(n_env, buffer_size), device=device, dtype=torch.long
)
self.next_observations = torch.zeros(
(n_env, buffer_size, n_obs), device=device, dtype=torch.float
)
if asymmetric_obs:
if self.playground_mode:
# Only store the privileged part of observations (n_critic_obs - n_obs)
self.privileged_obs_size = n_critic_obs - n_obs
self.privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
)
self.next_privileged_observations = torch.zeros(
(n_env, buffer_size, self.privileged_obs_size),
device=device,
dtype=torch.float,
)
else:
# Store full critic observations
self.critic_observations = torch.zeros(
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
)
self.next_critic_observations = torch.zeros(
(n_env, buffer_size, n_critic_obs), device=device, dtype=torch.float
)
self.ptr = 0
@torch.no_grad()
def extend(
self,
tensor_dict: TensorDict,
):
observations = tensor_dict["observations"]
actions = tensor_dict["actions"]
rewards = tensor_dict["next"]["rewards"]
dones = tensor_dict["next"]["dones"]
truncations = tensor_dict["next"]["truncations"]
next_observations = tensor_dict["next"]["observations"]
ptr = self.ptr % self.buffer_size
self.observations[:, ptr] = observations
self.actions[:, ptr] = actions
self.rewards[:, ptr] = rewards
self.dones[:, ptr] = dones
self.truncations[:, ptr] = truncations
self.next_observations[:, ptr] = next_observations
if self.asymmetric_obs:
critic_observations = tensor_dict["critic_observations"]
next_critic_observations = tensor_dict["next"]["critic_observations"]
if self.playground_mode:
# Extract and store only the privileged part
privileged_observations = critic_observations[:, self.n_obs :]
next_privileged_observations = next_critic_observations[:, self.n_obs :]
self.privileged_observations[:, ptr] = privileged_observations
self.next_privileged_observations[:, ptr] = next_privileged_observations
else:
# Store full critic observations
self.critic_observations[:, ptr] = critic_observations
self.next_critic_observations[:, ptr] = next_critic_observations
self.ptr += 1
@torch.no_grad()
def sample(self, batch_size: int):
# we will sample n_env * batch_size transitions
if self.n_steps == 1:
indices = torch.randint(
0,
min(self.buffer_size, self.ptr),
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
observations = torch.gather(self.observations, 1, obs_indices).reshape(
self.n_env * batch_size, self.n_obs
)
next_observations = torch.gather(
self.next_observations, 1, obs_indices
).reshape(self.n_env * batch_size, self.n_obs)
actions = torch.gather(self.actions, 1, act_indices).reshape(
self.n_env * batch_size, self.n_act
)
rewards = torch.gather(self.rewards, 1, indices).reshape(
self.n_env * batch_size
)
dones = torch.gather(self.dones, 1, indices).reshape(
self.n_env * batch_size
)
truncations = torch.gather(self.truncations, 1, indices).reshape(
self.n_env * batch_size
)
effective_n_steps = torch.ones_like(dones)
if self.asymmetric_obs:
if self.playground_mode:
# Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size)
next_privileged_observations = torch.gather(
self.next_privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations
critic_observations = torch.cat(
[observations, privileged_observations], dim=1
)
next_critic_observations = torch.cat(
[next_observations, next_privileged_observations], dim=1
)
else:
# Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
)
critic_observations = torch.gather(
self.critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
next_critic_observations = torch.gather(
self.next_critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
else:
# Sample base indices
if self.ptr >= self.buffer_size:
                # When the buffer is full, there is no protection against sampling across episode boundaries.
                # We avoid this by temporarily marking the most recently written slot (current_pos - 1)
                # as truncated when it is not done, and restoring the original flags after sampling.
                # https://github.com/DLR-RM/stable-baselines3/blob/b91050ca94f8bce7a0285c91f85da518d5a26223/stable_baselines3/common/buffers.py#L857-L860
                # TODO (Younggyo): Change the reference when this SB3 branch is merged
current_pos = self.ptr % self.buffer_size
curr_truncations = self.truncations[:, current_pos - 1].clone()
self.truncations[:, current_pos - 1] = torch.logical_not(
self.dones[:, current_pos - 1]
)
indices = torch.randint(
0,
self.buffer_size,
(self.n_env, batch_size),
device=self.device,
)
else:
# Buffer not full - ensure n-step sequence doesn't exceed valid data
max_start_idx = max(1, self.ptr - self.n_steps + 1)
indices = torch.randint(
0,
max_start_idx,
(self.n_env, batch_size),
device=self.device,
)
obs_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
act_indices = indices.unsqueeze(-1).expand(-1, -1, self.n_act)
# Get base transitions
observations = torch.gather(self.observations, 1, obs_indices).reshape(
self.n_env * batch_size, self.n_obs
)
actions = torch.gather(self.actions, 1, act_indices).reshape(
self.n_env * batch_size, self.n_act
)
if self.asymmetric_obs:
if self.playground_mode:
# Gather privileged observations
priv_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
)
privileged_observations = torch.gather(
self.privileged_observations, 1, priv_obs_indices
).reshape(self.n_env * batch_size, self.privileged_obs_size)
# Concatenate with regular observations to form full critic observations
critic_observations = torch.cat(
[observations, privileged_observations], dim=1
)
else:
# Gather full critic observations
critic_obs_indices = indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
)
critic_observations = torch.gather(
self.critic_observations, 1, critic_obs_indices
).reshape(self.n_env * batch_size, self.n_critic_obs)
# Create sequential indices for each sample
# This creates a [n_env, batch_size, n_step] tensor of indices
seq_offsets = torch.arange(self.n_steps, device=self.device).view(1, 1, -1)
all_indices = (
indices.unsqueeze(-1) + seq_offsets
) % self.buffer_size # [n_env, batch_size, n_step]
# Gather all rewards and terminal flags
# Using advanced indexing - result shapes: [n_env, batch_size, n_step]
all_rewards = torch.gather(
self.rewards.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices
)
all_dones = torch.gather(
self.dones.unsqueeze(-1).expand(-1, -1, self.n_steps), 1, all_indices
)
all_truncations = torch.gather(
self.truncations.unsqueeze(-1).expand(-1, -1, self.n_steps),
1,
all_indices,
)
# Create masks for rewards *after* first done
# This creates a cumulative product that zeroes out rewards after the first done
all_dones_shifted = torch.cat(
[torch.zeros_like(all_dones[:, :, :1]), all_dones[:, :, :-1]], dim=2
) # First reward should not be masked
done_masks = torch.cumprod(
1.0 - all_dones_shifted, dim=2
) # [n_env, batch_size, n_step]
effective_n_steps = done_masks.sum(2)
# Create discount factors
discounts = torch.pow(
self.gamma, torch.arange(self.n_steps, device=self.device)
) # [n_steps]
# Apply masks and discounts to rewards
masked_rewards = all_rewards * done_masks # [n_env, batch_size, n_step]
discounted_rewards = masked_rewards * discounts.view(
1, 1, -1
) # [n_env, batch_size, n_step]
# Sum rewards along the n_step dimension
n_step_rewards = discounted_rewards.sum(dim=2) # [n_env, batch_size]
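            # Worked example (illustrative numbers, not computed by the code above):
            # with gamma=0.99, n_steps=3, per-step rewards [1, 1, 1] and no dones,
            # the n-step return is 1 + 0.99 + 0.99**2 = 2.9701. If a done occurs at
            # the second step, the mask zeroes the third reward: 1 + 0.99 = 1.99.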
# Find index of first done or truncation or last step for each sequence
first_done = torch.argmax(
(all_dones > 0).float(), dim=2
) # [n_env, batch_size]
first_trunc = torch.argmax(
(all_truncations > 0).float(), dim=2
) # [n_env, batch_size]
# Handle case where there are no dones or truncations
no_dones = all_dones.sum(dim=2) == 0
no_truncs = all_truncations.sum(dim=2) == 0
# When no dones or truncs, use the last index
first_done = torch.where(no_dones, self.n_steps - 1, first_done)
first_trunc = torch.where(no_truncs, self.n_steps - 1, first_trunc)
# Take the minimum (first) of done or truncation
final_indices = torch.minimum(
first_done, first_trunc
) # [n_env, batch_size]
# Create indices to gather the final next observations
final_next_obs_indices = torch.gather(
all_indices, 2, final_indices.unsqueeze(-1)
).squeeze(
-1
) # [n_env, batch_size]
# Gather final values
final_next_observations = self.next_observations.gather(
1, final_next_obs_indices.unsqueeze(-1).expand(-1, -1, self.n_obs)
)
final_dones = self.dones.gather(1, final_next_obs_indices)
final_truncations = self.truncations.gather(1, final_next_obs_indices)
if self.asymmetric_obs:
if self.playground_mode:
# Gather final privileged observations
final_next_privileged_observations = (
self.next_privileged_observations.gather(
1,
final_next_obs_indices.unsqueeze(-1).expand(
-1, -1, self.privileged_obs_size
),
)
)
# Reshape for output
next_privileged_observations = (
final_next_privileged_observations.reshape(
self.n_env * batch_size, self.privileged_obs_size
)
)
# Concatenate with next observations to form full next critic observations
next_observations_reshaped = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs
)
next_critic_observations = torch.cat(
[next_observations_reshaped, next_privileged_observations],
dim=1,
)
else:
# Gather final next critic observations directly
final_next_critic_observations = (
self.next_critic_observations.gather(
1,
final_next_obs_indices.unsqueeze(-1).expand(
-1, -1, self.n_critic_obs
),
)
)
next_critic_observations = final_next_critic_observations.reshape(
self.n_env * batch_size, self.n_critic_obs
)
# Reshape everything to batch dimension
rewards = n_step_rewards.reshape(self.n_env * batch_size)
dones = final_dones.reshape(self.n_env * batch_size)
truncations = final_truncations.reshape(self.n_env * batch_size)
effective_n_steps = effective_n_steps.reshape(self.n_env * batch_size)
next_observations = final_next_observations.reshape(
self.n_env * batch_size, self.n_obs
)
out = TensorDict(
{
"observations": observations,
"actions": actions,
"next": {
"rewards": rewards,
"dones": dones,
"truncations": truncations,
"observations": next_observations,
"effective_n_steps": effective_n_steps,
},
},
batch_size=self.n_env * batch_size,
)
if self.asymmetric_obs:
out["critic_observations"] = critic_observations
out["next"]["critic_observations"] = next_critic_observations
if self.n_steps > 1 and self.ptr >= self.buffer_size:
# Roll back the truncation flags introduced for safe sampling
self.truncations[:, current_pos - 1] = curr_truncations
return out
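# Usage sketch for SimpleReplayBuffer (illustrative shapes and values; the helper
# below is not part of this module). `extend` is called once per vectorized env
# step with a TensorDict of [n_env, ...] tensors; `sample` returns flattened
# batches of shape [n_env * batch_size, ...].
def _replay_buffer_example():
    device = torch.device("cpu")
    buffer = SimpleReplayBuffer(
        n_env=4, buffer_size=128, n_obs=8, n_act=2, n_critic_obs=8,
        n_steps=3, gamma=0.99, device=device,
    )
    batch = TensorDict(
        {
            "observations": torch.randn(4, 8),
            "actions": torch.randn(4, 2),
            "next": TensorDict(
                {
                    "rewards": torch.randn(4),
                    "dones": torch.zeros(4, dtype=torch.long),
                    "truncations": torch.zeros(4, dtype=torch.long),
                    "observations": torch.randn(4, 8),
                },
                batch_size=4,
            ),
        },
        batch_size=4,
    )
    for _ in range(8):  # one write per env step, at the shared circular pointer
        buffer.extend(batch)
    sampled = buffer.sample(batch_size=32)  # tensors of shape [4 * 32, ...]
    return sampled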
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
def __init__(self, shape, device, eps=1e-2, until=None):
"""Initialize EmpiricalNormalization module.
Args:
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
            until (int or None): If specified, the module stops updating its statistics
                once the total number of observed samples exceeds this value.
"""
super().__init__()
self.eps = eps
self.until = until
self.device = device
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0).to(device))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("count", torch.tensor(0, dtype=torch.long).to(device))
@property
def mean(self):
return self._mean.squeeze(0).clone()
@property
def std(self):
return self._std.squeeze(0).clone()
@torch.no_grad()
def forward(
self, x: torch.Tensor, center: bool = True, update: bool = True
) -> torch.Tensor:
if x.shape[1:] != self._mean.shape[1:]:
raise ValueError(
f"Expected input of shape (*,{self._mean.shape[1:]}), got {x.shape}"
)
if self.training and update:
self.update(x)
if center:
return (x - self._mean) / (self._std + self.eps)
else:
return x / (self._std + self.eps)
@torch.jit.unused
def update(self, x):
if self.until is not None and self.count >= self.until:
return
if dist.is_available() and dist.is_initialized():
# Calculate global batch size arithmetically
local_batch_size = x.shape[0]
world_size = dist.get_world_size()
global_batch_size = world_size * local_batch_size
# Calculate the stats
x_shifted = x - self._mean
local_sum_shifted = torch.sum(x_shifted, dim=0, keepdim=True)
local_sum_sq_shifted = torch.sum(x_shifted.pow(2), dim=0, keepdim=True)
# Sync the stats across all processes
stats_to_sync = torch.cat([local_sum_shifted, local_sum_sq_shifted], dim=0)
dist.all_reduce(stats_to_sync, op=dist.ReduceOp.SUM)
global_sum_shifted, global_sum_sq_shifted = stats_to_sync
# Calculate the mean and variance of the global batch
batch_mean_shifted = global_sum_shifted / global_batch_size
batch_var = (
global_sum_sq_shifted / global_batch_size - batch_mean_shifted.pow(2)
)
batch_mean = batch_mean_shifted + self._mean
else:
global_batch_size = x.shape[0]
batch_mean = torch.mean(x, dim=0, keepdim=True)
batch_var = torch.var(x, dim=0, keepdim=True, unbiased=False)
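        # Merge the batch statistics into the running statistics with Chan's
        # parallel algorithm: for counts n_a (running) and n_b (batch), with
        # delta = batch_mean - running_mean,
        #   new_mean = running_mean + delta * n_b / (n_a + n_b)
        #   M2      = var_a * n_a + var_b * n_b + delta^2 * n_a * n_b / (n_a + n_b)
        #   new_var = M2 / (n_a + n_b)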
new_count = self.count + global_batch_size
# Update mean
delta = batch_mean - self._mean
self._mean.copy_(self._mean + delta * (global_batch_size / new_count))
        # Update variance (the delta must be taken against the mean *before* the
        # update; computing it against the updated mean would understate the variance)
        m_a = self._var * self.count
        m_b = batch_var * global_batch_size
        M2 = m_a + m_b + delta.pow(2) * (self.count * global_batch_size / new_count)
self._var.copy_(M2 / new_count)
self._std.copy_(self._var.sqrt())
self.count.copy_(new_count)
@torch.jit.unused
def inverse(self, y):
return y * (self._std + self.eps) + self._mean
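# Minimal sketch of EmpiricalNormalization usage (illustrative values): statistics
# are updated on forward passes while the module is in training mode and frozen in
# eval mode, so the same call site serves both.
def _obs_normalizer_example():
    norm = EmpiricalNormalization(shape=(8,), device=torch.device("cpu"))
    obs = torch.randn(32, 8)
    norm.train()
    normalized = norm(obs)  # updates running mean/std, then normalizes
    norm.eval()
    _ = norm(obs)  # normalizes only; statistics stay frozen
    restored = norm.inverse(normalized)  # recovers `obs` up to float error
    return restored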
class RewardNormalizer(nn.Module):
def __init__(
self,
gamma: float,
device: torch.device,
g_max: float = 10.0,
epsilon: float = 1e-8,
):
super().__init__()
self.register_buffer(
"G", torch.zeros(1, device=device)
) # running estimate of the discounted return
self.register_buffer("G_r_max", torch.zeros(1, device=device)) # running-max
self.G_rms = EmpiricalNormalization(shape=1, device=device)
self.gamma = gamma
self.g_max = g_max
self.epsilon = epsilon
def _scale_reward(self, rewards: torch.Tensor) -> torch.Tensor:
var_denominator = self.G_rms.std[0] + self.epsilon
min_required_denominator = self.G_r_max / self.g_max
denominator = torch.maximum(var_denominator, min_required_denominator)
return rewards / denominator
def update_stats(
self,
rewards: torch.Tensor,
dones: torch.Tensor,
):
self.G = self.gamma * (1 - dones) * self.G + rewards
self.G_rms.update(self.G.view(-1, 1))
local_max = torch.max(torch.abs(self.G))
if dist.is_available() and dist.is_initialized():
dist.all_reduce(local_max, op=dist.ReduceOp.MAX)
        self.G_r_max = torch.maximum(self.G_r_max, local_max)
def forward(self, rewards: torch.Tensor) -> torch.Tensor:
return self._scale_reward(rewards)
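# Minimal sketch of RewardNormalizer usage (illustrative shapes): `update_stats`
# tracks the running discounted return G per environment, and the forward call
# rescales raw rewards so normalized returns stay roughly within [-g_max, g_max].
def _reward_normalizer_example():
    normalizer = RewardNormalizer(gamma=0.99, device=torch.device("cpu"))
    rewards = torch.randn(16)
    dones = torch.zeros(16)
    normalizer.update_stats(rewards, dones)
    return normalizer(rewards)  # scaled rewards, same shape as the input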
class PerTaskEmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values for each task."""
def __init__(
self,
num_tasks: int,
shape: tuple,
device: torch.device,
eps: float = 1e-2,
        until: Optional[int] = None,
):
"""
Initialize PerTaskEmpiricalNormalization module.
Args:
num_tasks (int): The total number of tasks.
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
until (int or None): If specified, learns until the sum of batch sizes
for a specific task exceeds this value.
"""
super().__init__()
if not isinstance(shape, tuple):
shape = (shape,)
self.num_tasks = num_tasks
self.shape = shape
self.eps = eps
self.until = until
self.device = device
# Buffers now have a leading dimension for tasks
self.register_buffer("_mean", torch.zeros(num_tasks, *shape).to(device))
self.register_buffer("_var", torch.ones(num_tasks, *shape).to(device))
self.register_buffer("_std", torch.ones(num_tasks, *shape).to(device))
self.register_buffer(
"count", torch.zeros(num_tasks, dtype=torch.long).to(device)
)
def forward(
self, x: torch.Tensor, task_ids: torch.Tensor, center: bool = True
) -> torch.Tensor:
"""
Normalize the input tensor `x` using statistics for the given `task_ids`.
Args:
x (torch.Tensor): Input tensor of shape [num_envs, *shape].
task_ids (torch.Tensor): Tensor of task indices, shape [num_envs].
center (bool): If True, center the data by subtracting the mean.
"""
if x.shape[1:] != self.shape:
raise ValueError(f"Expected input shape (*, {self.shape}), got {x.shape}")
if x.shape[0] != task_ids.shape[0]:
raise ValueError("Batch size of x and task_ids must match.")
# Gather the stats for the tasks in the current batch
# Reshape task_ids for broadcasting: [num_envs] -> [num_envs, 1, ...]
view_shape = (task_ids.shape[0],) + (1,) * len(self.shape)
task_ids_expanded = task_ids.view(view_shape).expand_as(x)
mean = self._mean.gather(0, task_ids_expanded)
std = self._std.gather(0, task_ids_expanded)
if self.training:
self.update(x, task_ids)
if center:
return (x - mean) / (std + self.eps)
else:
return x / (std + self.eps)
@torch.jit.unused
def update(self, x: torch.Tensor, task_ids: torch.Tensor):
"""Update running statistics for the tasks present in the batch."""
unique_tasks = torch.unique(task_ids)
for task_id in unique_tasks:
if self.until is not None and self.count[task_id] >= self.until:
continue
# Create a mask to select data for the current task
mask = task_ids == task_id
x_task = x[mask]
batch_size = x_task.shape[0]
if batch_size == 0:
continue
# Update count for this task
old_count = self.count[task_id].clone()
new_count = old_count + batch_size
# Update mean
task_mean = self._mean[task_id]
batch_mean = torch.mean(x_task, dim=0)
delta = batch_mean - task_mean
self._mean[task_id].copy_(task_mean + (batch_size / new_count) * delta)
# Update variance using Chan's parallel algorithm
if old_count > 0:
batch_var = torch.var(x_task, dim=0, unbiased=False)
m_a = self._var[task_id] * old_count
m_b = batch_var * batch_size
M2 = m_a + m_b + (delta**2) * (old_count * batch_size / new_count)
self._var[task_id].copy_(M2 / new_count)
else:
# For the first batch of this task
self._var[task_id].copy_(torch.var(x_task, dim=0, unbiased=False))
self._std[task_id].copy_(torch.sqrt(self._var[task_id]))
self.count[task_id].copy_(new_count)
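# Minimal sketch of PerTaskEmpiricalNormalization usage (illustrative values):
# each row of `x` is normalized with the statistics of the task in `task_ids`.
def _per_task_normalizer_example():
    norm = PerTaskEmpiricalNormalization(
        num_tasks=3, shape=(8,), device=torch.device("cpu")
    )
    x = torch.randn(6, 8)
    task_ids = torch.tensor([0, 0, 1, 1, 2, 2])
    norm.train()
    return norm(x, task_ids)  # per-task statistics are updated, then applied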
class PerTaskRewardNormalizer(nn.Module):
def __init__(
self,
num_tasks: int,
gamma: float,
device: torch.device,
g_max: float = 10.0,
epsilon: float = 1e-8,
):
"""
        Per-task reward normalizer; motivated by BRC (https://arxiv.org/abs/2505.23150v1).
"""
super().__init__()
self.num_tasks = num_tasks
self.gamma = gamma
self.g_max = g_max
self.epsilon = epsilon
self.device = device
# Per-task running estimate of the discounted return
self.register_buffer("G", torch.zeros(num_tasks, device=device))
# Per-task running-max of the discounted return
self.register_buffer("G_r_max", torch.zeros(num_tasks, device=device))
# Use the new per-task normalizer for the statistics of G
self.G_rms = PerTaskEmpiricalNormalization(
num_tasks=num_tasks, shape=(1,), device=device
)
def _scale_reward(
self, rewards: torch.Tensor, task_ids: torch.Tensor
) -> torch.Tensor:
"""
Scales rewards using per-task statistics.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
# Gather stats for the tasks in the batch
std_for_batch = self.G_rms._std.gather(0, task_ids.unsqueeze(-1)).squeeze(-1)
g_r_max_for_batch = self.G_r_max.gather(0, task_ids)
var_denominator = std_for_batch + self.epsilon
min_required_denominator = g_r_max_for_batch / self.g_max
denominator = torch.maximum(var_denominator, min_required_denominator)
# Add a small epsilon to the final denominator to prevent division by zero
# in case g_r_max is also zero.
return rewards / (denominator + self.epsilon)
def update_stats(
self, rewards: torch.Tensor, dones: torch.Tensor, task_ids: torch.Tensor
):
"""
Updates the running discounted return and its statistics for each task.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
dones (torch.Tensor): Done tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
if not (rewards.shape == dones.shape == task_ids.shape):
raise ValueError("rewards, dones, and task_ids must have the same shape.")
# === Update G (running discounted return) ===
# Gather the previous G values for the tasks in the batch
prev_G = self.G.gather(0, task_ids)
# Update G for each environment based on its own reward and done signal
new_G = self.gamma * (1 - dones.float()) * prev_G + rewards
# Scatter the updated G values back to the main buffer
self.G.scatter_(0, task_ids, new_G)
# === Update G_rms (statistics of G) ===
# The update function handles the per-task logic internally
self.G_rms.update(new_G.unsqueeze(-1), task_ids)
# === Update G_r_max (running max of |G|) ===
prev_G_r_max = self.G_r_max.gather(0, task_ids)
# Update the max for each environment
updated_G_r_max = torch.maximum(prev_G_r_max, torch.abs(new_G))
# Scatter the new maxes back to the main buffer
self.G_r_max.scatter_(0, task_ids, updated_G_r_max)
def forward(self, rewards: torch.Tensor, task_ids: torch.Tensor) -> torch.Tensor:
"""
Normalizes rewards. During training, it also updates the running statistics.
Args:
rewards (torch.Tensor): Reward tensor, shape [num_envs].
task_ids (torch.Tensor): Task indices, shape [num_envs].
"""
return self._scale_reward(rewards, task_ids)
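# Minimal sketch of PerTaskRewardNormalizer usage (illustrative shapes): statistics
# are kept per task, so environments running different tasks are normalized
# independently of one another.
def _per_task_reward_example():
    normalizer = PerTaskRewardNormalizer(
        num_tasks=3, gamma=0.99, device=torch.device("cpu")
    )
    rewards = torch.randn(6)
    dones = torch.zeros(6)
    task_ids = torch.tensor([0, 0, 1, 1, 2, 2])
    normalizer.update_stats(rewards, dones, task_ids)
    return normalizer(rewards, task_ids)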
def cpu_state(sd):
    # detach & move to host without blocking the compute stream
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
def save_params(
global_step,
actor,
qnet,
qnet_target,
obs_normalizer,
critic_obs_normalizer,
args,
save_path,
):
"""Save model parameters and training configuration to disk."""
os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_dict = {
"actor_state_dict": cpu_state(get_ddp_state_dict(actor)),
"qnet_state_dict": cpu_state(get_ddp_state_dict(qnet)),
"qnet_target_state_dict": cpu_state(get_ddp_state_dict(qnet_target)),
"obs_normalizer_state": (
cpu_state(obs_normalizer.state_dict())
if hasattr(obs_normalizer, "state_dict")
else None
),
"critic_obs_normalizer_state": (
cpu_state(critic_obs_normalizer.state_dict())
if hasattr(critic_obs_normalizer, "state_dict")
else None
),
"args": vars(args), # Save all arguments
"global_step": global_step,
}
torch.save(save_dict, save_path, _use_new_zipfile_serialization=True)
print(f"Saved parameters and configuration to {save_path}")
def get_ddp_state_dict(model):
"""Get state dict from model, handling DDP wrapper if present."""
if hasattr(model, "module"):
return model.module.state_dict()
return model.state_dict()
def load_ddp_state_dict(model, state_dict):
"""Load state dict into model, handling DDP wrapper if present."""
if hasattr(model, "module"):
model.module.load_state_dict(state_dict)
else:
model.load_state_dict(state_dict)
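# Loading counterpart to `save_params` (a sketch, not a function from the original
# module): restores model weights from a checkpoint written by `save_params` and
# returns the saved config and step counter. Normalizer states are left to the
# caller, since the normalizers may be plain nn.Identity modules without state.
def load_params(load_path, actor, qnet, qnet_target, device="cpu"):
    checkpoint = torch.load(load_path, map_location=device)
    load_ddp_state_dict(actor, checkpoint["actor_state_dict"])
    load_ddp_state_dict(qnet, checkpoint["qnet_state_dict"])
    load_ddp_state_dict(qnet_target, checkpoint["qnet_target_state_dict"])
    return checkpoint["args"], checkpoint["global_step"]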
@torch.no_grad()
def mark_step():
# call this once per iteration *before* any compiled function
torch.compiler.cudagraph_mark_step_begin()
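# Usage sketch for `mark_step` (illustrative; `policy_update` stands in for any
# function compiled with torch.compile(mode="reduce-overhead")): marking the step
# boundary before the first compiled call prevents CUDA graphs from overwriting
# output tensors of the previous iteration while they are still in use.
def _mark_step_example(policy_update, batches):
    for batch in batches:
        mark_step()  # must come before any compiled call in this iteration
        policy_update(batch)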