Support FastTD3 + SimbaV2 (#13)
- Support hyperspherical normalization - Support loading FastTD3 + SimbaV2 for both training and inference - Support (experimental) reward normalization that uses SimbaV2's formulation -- not working that well though - Updated README for FastTD3 + SimbaV2
This commit is contained in:
parent
1014bf7e82
commit
6e890eebd2
12
README.md
12
README.md
@ -10,6 +10,8 @@ For more information, please see our [project webpage](https://younggyo.me/fast_
|
|||||||
|
|
||||||
|
|
||||||
## ❗ Updates
|
## ❗ Updates
|
||||||
|
- **[Jun/15/2025]** Added support for FastTD3 + SimbaV2! It's faster to train, and often achieves better asymptotic performance.
|
||||||
|
|
||||||
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
- **[Jun/6/2025]** Thanks to [Antonin Raffin](https://araffin.github.io/) ([@araffin](https://github.com/araffin)), we fixed the issues when using `n_steps` > 1, which stabilizes training with n-step return quite a lot!
|
||||||
|
|
||||||
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
- **[Jun/1/2025]** Updated the figures in the technical report to report deterministic evaluation for IsaacLab tasks.
|
||||||
@ -99,21 +101,29 @@ Please see `fast_td3/hyperparams.py` for information regarding hyperparameters!
|
|||||||
### HumanoidBench Experiments
|
### HumanoidBench Experiments
|
||||||
```bash
|
```bash
|
||||||
conda activate fasttd3_hb
|
conda activate fasttd3_hb
|
||||||
|
# FastTD3
|
||||||
python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --seed 1
|
python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --seed 1
|
||||||
|
# FastTD3 + SimbaV2
|
||||||
|
python fast_td3/train.py --env_name h1hand-hurdle-v0 --exp_name FastTD3 --render_interval 5000 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
|
||||||
```
|
```
|
||||||
|
|
||||||
### MuJoCo Playground Experiments
|
### MuJoCo Playground Experiments
|
||||||
```bash
|
```bash
|
||||||
conda activate fasttd3_playground
|
conda activate fasttd3_playground
|
||||||
|
# FastTD3
|
||||||
python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
||||||
python fast_td3/train.py --env_name G1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
python fast_td3/train.py --env_name G1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --seed 1
|
||||||
|
# FastTD3 + SimbaV2
|
||||||
|
python fast_td3/train.py --env_name T1JoystickFlatTerrain --exp_name FastTD3 --render_interval 5000 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
|
||||||
```
|
```
|
||||||
|
|
||||||
### IsaacLab Experiments
|
### IsaacLab Experiments
|
||||||
```bash
|
```bash
|
||||||
conda activate fasttd3_isaaclab
|
conda activate fasttd3_isaaclab
|
||||||
|
# FastTD3
|
||||||
python fast_td3/train.py --env_name Isaac-Velocity-Flat-G1-v0 --exp_name FastTD3 --render_interval 0 --seed 1
|
python fast_td3/train.py --env_name Isaac-Velocity-Flat-G1-v0 --exp_name FastTD3 --render_interval 0 --seed 1
|
||||||
python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --seed 1
|
# FastTD3 + SimbaV2
|
||||||
|
python fast_td3/train.py --env_name Isaac-Repose-Cube-Allegro-Direct-v0 --exp_name FastTD3 --render_interval 0 --agent fasttd3_simbav2 --batch_size 8192 --critic_learning_rate_end 3e-5 --actor_learning_rate_end 3e-5 --weight_decay 0.0 --critic_hidden_dim 512 --critic_num_blocks 2 --actor_hidden_dim 256 --actor_num_blocks 1 --seed 1
|
||||||
```
|
```
|
||||||
|
|
||||||
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front of each argument. For instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
|
**Quick note:** For boolean-based arguments, you can set them to False by adding `no_` in front of each argument. For instance, if you want to disable Clipped Q Learning, you can specify `--no_use_cdq` in your command.
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from .fast_td3_utils import EmpiricalNormalization
|
from .fast_td3_utils import EmpiricalNormalization
|
||||||
from .fast_td3 import Actor
|
from .fast_td3 import Actor
|
||||||
|
from .fast_td3_simbav2 import Actor as ActorSimbaV2
|
||||||
|
|
||||||
|
|
||||||
class Policy(nn.Module):
|
class Policy(nn.Module):
|
||||||
@ -9,12 +12,18 @@ class Policy(nn.Module):
|
|||||||
self,
|
self,
|
||||||
n_obs: int,
|
n_obs: int,
|
||||||
n_act: int,
|
n_act: int,
|
||||||
num_envs: int,
|
args: dict,
|
||||||
init_scale: float,
|
agent: str = "fasttd3",
|
||||||
actor_hidden_dim: int,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.actor = Actor(
|
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
num_envs = args["num_envs"]
|
||||||
|
init_scale = args["init_scale"]
|
||||||
|
actor_hidden_dim = args["actor_hidden_dim"]
|
||||||
|
|
||||||
|
actor_kwargs = dict(
|
||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
n_act=n_act,
|
n_act=n_act,
|
||||||
num_envs=num_envs,
|
num_envs=num_envs,
|
||||||
@ -22,6 +31,31 @@ class Policy(nn.Module):
|
|||||||
init_scale=init_scale,
|
init_scale=init_scale,
|
||||||
hidden_dim=actor_hidden_dim,
|
hidden_dim=actor_hidden_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if agent == "fasttd3":
|
||||||
|
actor_cls = Actor
|
||||||
|
elif agent == "fasttd3_simbav2":
|
||||||
|
actor_cls = ActorSimbaV2
|
||||||
|
|
||||||
|
actor_num_blocks = args["actor_num_blocks"]
|
||||||
|
actor_kwargs.pop("init_scale")
|
||||||
|
actor_kwargs.update(
|
||||||
|
{
|
||||||
|
"scaler_init": math.sqrt(2.0 / actor_hidden_dim),
|
||||||
|
"scaler_scale": math.sqrt(2.0 / actor_hidden_dim),
|
||||||
|
"alpha_init": 1.0 / (actor_num_blocks + 1),
|
||||||
|
"alpha_scale": 1.0 / math.sqrt(actor_hidden_dim),
|
||||||
|
"expansion": 4,
|
||||||
|
"c_shift": 3.0,
|
||||||
|
"num_blocks": actor_num_blocks,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent {agent} not supported")
|
||||||
|
|
||||||
|
self.actor = actor_cls(
|
||||||
|
**actor_kwargs,
|
||||||
|
)
|
||||||
self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")
|
self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")
|
||||||
|
|
||||||
self.actor.eval()
|
self.actor.eval()
|
||||||
@ -45,17 +79,25 @@ def load_policy(checkpoint_path):
|
|||||||
)
|
)
|
||||||
args = torch_checkpoint["args"]
|
args = torch_checkpoint["args"]
|
||||||
|
|
||||||
n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
|
agent = args.get("agent", "fasttd3")
|
||||||
n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]
|
if agent == "fasttd3":
|
||||||
|
n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
|
||||||
|
n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]
|
||||||
|
elif agent == "fasttd3_simbav2":
|
||||||
|
# TODO: Too hard-coded, maybe save n_obs and n_act in the checkpoint?
|
||||||
|
n_obs = (
|
||||||
|
torch_checkpoint["actor_state_dict"]["embedder.w.w.weight"].shape[-1] - 1
|
||||||
|
)
|
||||||
|
n_act = torch_checkpoint["actor_state_dict"]["predictor.mean_bias"].shape[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent {agent} not supported")
|
||||||
|
|
||||||
policy = Policy(
|
policy = Policy(
|
||||||
n_obs=n_obs,
|
n_obs=n_obs,
|
||||||
n_act=n_act,
|
n_act=n_act,
|
||||||
num_envs=args["num_envs"],
|
args=args,
|
||||||
init_scale=args["init_scale"],
|
agent=agent,
|
||||||
actor_hidden_dim=args["actor_hidden_dim"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"])
|
policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"])
|
||||||
|
|
||||||
if len(torch_checkpoint["obs_normalizer_state"]) == 0:
|
if len(torch_checkpoint["obs_normalizer_state"]) == 0:
|
||||||
|
494
fast_td3/fast_td3_simbav2.py
Normal file
494
fast_td3/fast_td3_simbav2.py
Normal file
@ -0,0 +1,494 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def l2normalize(
    tensor: torch.Tensor, axis: int = -1, eps: float = 1e-8
) -> torch.Tensor:
    """Rescale `tensor` to (approximately) unit L2 norm along `axis`.

    Args:
        tensor: Input tensor.
        axis: Dimension along which the norm is taken.
        eps: Added to the norm before dividing, so an all-zero slice maps to
            zeros instead of producing NaN.

    Returns:
        A tensor with the same shape as `tensor`.
    """
    magnitude = torch.linalg.norm(tensor, ord=2, dim=axis, keepdim=True)
    denominator = magnitude + eps
    return tensor / denominator
|
||||||
|
|
||||||
|
|
||||||
|
class Scaler(nn.Module):
    """Learnable element-wise scaling layer.

    A length-`dim` parameter is created with every entry set to `init * scale`.
    The forward pass multiplies the input by that parameter and by the fixed
    (non-learned) constant `init / scale`.

    NOTE(review): at construction the effective gain is therefore
    `(init * scale) * (init / scale) == init ** 2` -- confirm this is the
    intended parameterisation.
    """

    def __init__(
        self,
        dim: int,
        init: float = 1.0,
        scale: float = 1.0,
        device: torch.device = None,
    ):
        super().__init__()
        initial_value = init * scale
        self.scaler = nn.Parameter(torch.full((dim,), initial_value, device=device))
        # Plain Python float: not a parameter or buffer, so it is not trained
        # and not stored in the state dict.
        self.forward_scaler = init / scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` multiplied by the learned scale and the fixed constant."""
        gain = self.scaler * self.forward_scaler
        return gain * x
|
||||||
|
|
||||||
|
|
||||||
|
class HyperDense(nn.Module):
    """Bias-free linear layer whose weight is initialised orthogonally.

    Args:
        in_dim: Size of the last dimension of the input.
        hidden_dim: Size of the last dimension of the output.
        device: Device on which the weight is allocated.
    """

    def __init__(self, in_dim: int, hidden_dim: int, device: torch.device = None):
        super().__init__()
        linear = nn.Linear(in_dim, hidden_dim, bias=False, device=device)
        # Overwrite nn.Linear's default initialisation with an orthogonal one.
        nn.init.orthogonal_(linear.weight, gain=1.0)
        self.w = linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map to `x`."""
        projected = self.w(x)
        return projected
|
||||||
|
|
||||||
|
|
||||||
|
class HyperMLP(nn.Module):
    """Two-layer MLP: HyperDense -> Scaler -> ReLU -> HyperDense -> L2-normalise.

    Args:
        in_dim: Input feature size.
        hidden_dim: Width of the intermediate layer.
        out_dim: Output feature size.
        scaler_init: `init` passed to the intermediate `Scaler`.
        scaler_scale: `scale` passed to the intermediate `Scaler`.
        eps: Constant added after the ReLU.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        scaler_init: float,
        scaler_scale: float,
        eps: float = 1e-8,
        device: torch.device = None,
    ):
        super().__init__()
        self.eps = eps
        self.w1 = HyperDense(in_dim, hidden_dim, device=device)
        self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device)
        self.w2 = HyperDense(hidden_dim, out_dim, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised (last dimension) output of the MLP."""
        hidden = self.scaler(self.w1(x))
        # Adding `eps` keeps the activation from being an all-zero vector.
        hidden = F.relu(hidden) + self.eps
        return l2normalize(self.w2(hidden), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperEmbedder(nn.Module):
    """Input embedding: append a constant feature, normalise, project, scale, normalise.

    Args:
        in_dim: Size of the raw input's last dimension.
        hidden_dim: Size of the embedding.
        scaler_init: `init` passed to the `Scaler`.
        scaler_scale: `scale` passed to the `Scaler`.
        c_shift: Value of the constant feature appended to every input vector.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        scaler_init: float,
        scaler_scale: float,
        c_shift: float,
        device: torch.device = None,
    ):
        super().__init__()
        self.c_shift = c_shift
        # One extra input column for the constant feature added in forward().
        self.w = HyperDense(in_dim + 1, hidden_dim, device=device)
        self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed `x`; the result is L2-normalised along the last dimension."""
        shift_shape = list(x.shape[:-1]) + [1]
        shift_column = torch.full(shift_shape, self.c_shift, device=x.device)
        augmented = l2normalize(torch.cat([x, shift_column], dim=-1), axis=-1)
        embedded = self.scaler(self.w(augmented))
        return l2normalize(embedded, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperLERPBlock(nn.Module):
    """Residual block that linearly interpolates (LERP) between input and MLP output.

    The block computes `x + alpha * (mlp(x) - x)` with a learnable per-feature
    `alpha` (a `Scaler`), then L2-normalises the result along the last dimension.

    Args:
        hidden_dim: Feature size of the block's input and output.
        scaler_init: Base `init` for the MLP's scaler (divided by sqrt(expansion)).
        scaler_scale: Base `scale` for the MLP's scaler (divided by sqrt(expansion)).
        alpha_init: `init` for the interpolation scaler.
        alpha_scale: `scale` for the interpolation scaler.
        expansion: Width multiplier of the MLP's intermediate layer.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        hidden_dim: int,
        scaler_init: float,
        scaler_scale: float,
        alpha_init: float,
        alpha_scale: float,
        expansion: int = 4,
        device: torch.device = None,
    ):
        super().__init__()
        mlp_kwargs = dict(
            in_dim=hidden_dim,
            hidden_dim=hidden_dim * expansion,
            out_dim=hidden_dim,
            scaler_init=scaler_init / math.sqrt(expansion),
            scaler_scale=scaler_scale / math.sqrt(expansion),
            device=device,
        )
        self.mlp = HyperMLP(**mlp_kwargs)
        self.alpha_scaler = Scaler(
            dim=hidden_dim, init=alpha_init, scale=alpha_scale, device=device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalised interpolation between `x` and `mlp(x)`."""
        # Step from the input towards the MLP output by a learned fraction.
        delta = self.mlp(x) - x
        blended = x + self.alpha_scaler(delta)
        return l2normalize(blended, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperTanhPolicy(nn.Module):
    """Policy head producing tanh-squashed actions.

    Pipeline: HyperDense -> Scaler -> HyperDense -> add learned bias -> tanh.

    Args:
        hidden_dim: Feature size of the incoming representation.
        action_dim: Number of action dimensions produced.
        scaler_init: `init` passed to the `Scaler`.
        scaler_scale: `scale` passed to the `Scaler`.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        hidden_dim: int,
        action_dim: int,
        scaler_init: float,
        scaler_scale: float,
        device: torch.device = None,
    ):
        super().__init__()
        self.mean_w1 = HyperDense(hidden_dim, hidden_dim, device=device)
        self.mean_scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device)
        self.mean_w2 = HyperDense(hidden_dim, action_dim, device=device)
        # HyperDense has no bias, so the output bias is a separate parameter.
        self.mean_bias = nn.Parameter(torch.zeros(action_dim, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features `x` to actions in (-1, 1)."""
        hidden = self.mean_scaler(self.mean_w1(x))
        pre_activation = self.mean_w2(hidden) + self.mean_bias
        return torch.tanh(pre_activation)
|
||||||
|
|
||||||
|
|
||||||
|
class HyperCategoricalValue(nn.Module):
    """Value head emitting unnormalised logits over `num_bins` categories.

    Pipeline: HyperDense -> Scaler -> HyperDense -> add learned bias.
    No softmax is applied here; callers receive raw logits.

    Args:
        hidden_dim: Feature size of the incoming representation.
        num_bins: Number of output categories.
        scaler_init: `init` passed to the `Scaler`.
        scaler_scale: `scale` passed to the `Scaler`.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_bins: int,
        scaler_init: float,
        scaler_scale: float,
        device: torch.device = None,
    ):
        super().__init__()
        self.w1 = HyperDense(hidden_dim, hidden_dim, device=device)
        self.scaler = Scaler(hidden_dim, scaler_init, scaler_scale, device=device)
        self.w2 = HyperDense(hidden_dim, num_bins, device=device)
        # HyperDense has no bias, so the per-bin bias is a separate parameter.
        self.bias = nn.Parameter(torch.zeros(num_bins, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits whose last dimension has size `num_bins`."""
        hidden = self.scaler(self.w1(x))
        return self.w2(hidden) + self.bias
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionalQNetwork(nn.Module):
    """Categorical (distributional) Q-network built from the Hyper* layers.

    The network embeds the concatenated observation and action, passes the
    embedding through `num_blocks` `HyperLERPBlock`s, and predicts `num_atoms`
    logits over a fixed value support spanning `[v_min, v_max]`.

    Args:
        n_obs: Observation feature size.
        n_act: Action feature size.
        num_atoms: Number of atoms in the value support.
        v_min: Smallest value of the support.
        v_max: Largest value of the support.
        hidden_dim: Width of the embedding and of every block.
        scaler_init: `init` for the embedder and block scalers.
        scaler_scale: `scale` for the embedder and block scalers.
        alpha_init: `init` for each block's interpolation scaler.
        alpha_scale: `scale` for each block's interpolation scaler.
        num_blocks: Number of `HyperLERPBlock`s in the encoder.
        c_shift: Constant feature appended by the embedder.
        expansion: Width multiplier inside each block's MLP.
        device: Device on which parameters are allocated.
    """

    def __init__(
        self,
        n_obs: int,
        n_act: int,
        num_atoms: int,
        v_min: float,
        v_max: float,
        hidden_dim: int,
        scaler_init: float,
        scaler_scale: float,
        alpha_init: float,
        alpha_scale: float,
        num_blocks: int,
        c_shift: float,
        expansion: int,
        device: torch.device = None,
    ):
        super().__init__()

        self.embedder = HyperEmbedder(
            in_dim=n_obs + n_act,
            hidden_dim=hidden_dim,
            scaler_init=scaler_init,
            scaler_scale=scaler_scale,
            c_shift=c_shift,
            device=device,
        )

        self.encoder = nn.Sequential(
            *[
                HyperLERPBlock(
                    hidden_dim=hidden_dim,
                    scaler_init=scaler_init,
                    scaler_scale=scaler_scale,
                    alpha_init=alpha_init,
                    alpha_scale=alpha_scale,
                    expansion=expansion,
                    device=device,
                )
                for _ in range(num_blocks)
            ]
        )

        # The value head uses fixed scaler settings (1.0 / 1.0) rather than the
        # scaler_init / scaler_scale given to the rest of the network.
        self.predictor = HyperCategoricalValue(
            hidden_dim=hidden_dim,
            num_bins=num_atoms,
            scaler_init=1.0,
            scaler_scale=1.0,
            device=device,
        )
        self.v_min = v_min
        self.v_max = v_max
        self.num_atoms = num_atoms

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return unnormalised logits over the atoms for each (obs, action) row.

        `obs` and `actions` are concatenated along dimension 1.
        """
        x = torch.cat([obs, actions], 1)
        x = self.embedder(x)
        x = self.encoder(x)
        x = self.predictor(x)
        return x

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: torch.Tensor,
        q_support: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Project the Bellman-shifted value distribution back onto the support.

        Each support atom `z` is mapped to
        `rewards + bootstrap * discount * z`, clamped to `[v_min, v_max]`, and
        the probability this network assigns to `z` at `(obs, actions)` is
        split between the two support atoms neighbouring the shifted value.

        Args:
            obs: Observations fed to `forward`.
            actions: Actions fed to `forward`.
            rewards: Per-sample rewards; `unsqueeze(1)` is applied, so one value
                per batch row is expected.
            bootstrap: Per-sample multiplier on the discounted support
                (presumably 0 for terminal transitions -- confirm with caller).
            discount: Per-sample discount as a tensor (`unsqueeze(1)` is applied).
            q_support: Atom values of the support.
            device: Device on which the index offsets are created.

        Returns:
            Tensor with the same shape as the softmax of `forward`'s output,
            holding the projected probabilities.
        """
        delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        batch_size = rewards.shape[0]

        target_z = (
            rewards.unsqueeze(1)
            + bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
        )
        target_z = target_z.clamp(self.v_min, self.v_max)
        # Fractional atom index of every shifted value.
        b = (target_z - self.v_min) / delta_z
        lower = torch.floor(b).long()
        upper = torch.ceil(b).long()

        # When b lands exactly on an atom, lower == upper and both interpolation
        # weights below would be zero, dropping the mass. Separate the indices so
        # exactly one weight becomes 1. The two adjustments must be applied
        # sequentially: evaluating both conditions on the unmodified indices
        # moves *both* of them for an interior atom, which puts the full mass on
        # each neighbour (doubling it) instead of on the atom itself.
        lower = torch.where((upper > 0) & (lower == upper), lower - 1, lower)
        upper = torch.where(
            (lower < (self.num_atoms - 1)) & (lower == upper), upper + 1, upper
        )

        next_dist = F.softmax(self.forward(obs, actions), dim=1)
        proj_dist = torch.zeros_like(next_dist)
        # Start index of each batch row in the flattened distribution. Integer
        # arithmetic keeps the indices exact for any batch size.
        offset = (torch.arange(batch_size, device=device) * self.num_atoms).unsqueeze(1)
        flat_proj = proj_dist.view(-1)
        flat_proj.index_add_(
            0, (lower + offset).view(-1), (next_dist * (upper.float() - b)).view(-1)
        )
        flat_proj.index_add_(
            0, (upper + offset).view(-1), (next_dist * (b - lower.float())).view(-1)
        )
        return proj_dist
|
||||||
|
|
||||||
|
|
||||||
|
class Critic(nn.Module):
    """Pair of identically configured `DistributionalQNetwork`s sharing one support.

    The value support (`num_atoms` evenly spaced points from `v_min` to `v_max`)
    is stored as the buffer `q_support`. All constructor arguments are forwarded
    unchanged to both Q-networks.
    """

    def __init__(
        self,
        n_obs: int,
        n_act: int,
        num_atoms: int,
        v_min: float,
        v_max: float,
        hidden_dim: int,
        scaler_init: float,
        scaler_scale: float,
        alpha_init: float,
        alpha_scale: float,
        num_blocks: int,
        c_shift: float,
        expansion: int,
        device: torch.device = None,
    ):
        super().__init__()
        qnet_kwargs = dict(
            n_obs=n_obs,
            n_act=n_act,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            scaler_init=scaler_init,
            scaler_scale=scaler_scale,
            alpha_init=alpha_init,
            alpha_scale=alpha_scale,
            num_blocks=num_blocks,
            c_shift=c_shift,
            expansion=expansion,
            hidden_dim=hidden_dim,
            device=device,
        )
        # Two separately initialised networks built from the same configuration.
        self.qnet1 = DistributionalQNetwork(**qnet_kwargs)
        self.qnet2 = DistributionalQNetwork(**qnet_kwargs)

        support = torch.linspace(v_min, v_max, num_atoms, device=device)
        self.register_buffer("q_support", support)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return the pair `(qnet1 output, qnet2 output)` for `(obs, actions)`."""
        first = self.qnet1(obs, actions)
        second = self.qnet2(obs, actions)
        return first, second

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: float,
        ) -> torch.Tensor:
        """Run both networks' `projection` using this module's `q_support`.

        All arguments are forwarded unchanged; the support and its device are
        supplied from the `q_support` buffer. Returns the pair
        `(qnet1 projection, qnet2 projection)`.
        """
        shared_args = (
            obs,
            actions,
            rewards,
            bootstrap,
            discount,
            self.q_support,
            self.q_support.device,
        )
        q1_proj = self.qnet1.projection(*shared_args)
        q2_proj = self.qnet2.projection(*shared_args)
        return q1_proj, q2_proj

    def get_value(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the expectation of `q_support` under `probs` (summed over dim 1)."""
        weighted = probs * self.q_support
        return weighted.sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class Actor(nn.Module):
    """Deterministic actor with per-environment exploration noise.

    The policy is HyperEmbedder -> `num_blocks` x HyperLERPBlock ->
    HyperTanhPolicy. Each of the `num_envs` environments has its own noise
    scale, drawn uniformly from `[std_min, std_max)` and kept in the buffer
    `noise_scales` (one row per environment).
    """

    def __init__(
        self,
        n_obs: int,
        n_act: int,
        num_envs: int,
        hidden_dim: int,
        scaler_init: float,
        scaler_scale: float,
        alpha_init: float,
        alpha_scale: float,
        expansion: int,
        c_shift: float,
        num_blocks: int,
        std_min: float = 0.05,
        std_max: float = 0.8,
        device: torch.device = None,
    ):
        super().__init__()
        self.n_act = n_act
        self.n_envs = num_envs

        self.embedder = HyperEmbedder(
            in_dim=n_obs,
            hidden_dim=hidden_dim,
            scaler_init=scaler_init,
            scaler_scale=scaler_scale,
            c_shift=c_shift,
            device=device,
        )

        blocks = []
        for _ in range(num_blocks):
            blocks.append(
                HyperLERPBlock(
                    hidden_dim=hidden_dim,
                    scaler_init=scaler_init,
                    scaler_scale=scaler_scale,
                    alpha_init=alpha_init,
                    alpha_scale=alpha_scale,
                    expansion=expansion,
                    device=device,
                )
            )
        self.encoder = nn.Sequential(*blocks)

        self.predictor = HyperTanhPolicy(
            hidden_dim=hidden_dim,
            action_dim=n_act,
            scaler_init=scaler_init,
            scaler_scale=scaler_scale,
            device=device,
        )

        # One exploration noise scale per environment.
        initial_scales = (
            torch.rand(num_envs, 1, device=device) * (std_max - std_min) + std_min
        )
        self.register_buffer("noise_scales", initial_scales)
        self.register_buffer("std_min", torch.as_tensor(std_min, device=device))
        self.register_buffer("std_max", torch.as_tensor(std_max, device=device))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the noise-free action for `obs`."""
        features = self.encoder(self.embedder(obs))
        return self.predictor(features)

    def explore(
        self, obs: torch.Tensor, dones: torch.Tensor = None, deterministic: bool = False
    ) -> torch.Tensor:
        """Return actions for `obs`, with Gaussian exploration noise unless deterministic.

        Args:
            obs: Observations passed to the policy.
            dones: Optional per-environment done flags. Environments whose flag
                is > 0 get a freshly sampled noise scale; the update happens
                even when `deterministic` is True.
            deterministic: If True, return the policy output without noise.

        Returns:
            The policy output, plus `randn * noise_scales` when not
            deterministic. The sum is not clipped here.
        """
        any_done = dones is not None and dones.sum() > 0
        if any_done:
            # Sample a candidate scale for every environment, then keep it only
            # where the environment is done.
            span = self.std_max - self.std_min
            candidate_scales = (
                torch.rand(self.n_envs, 1, device=obs.device) * span + self.std_min
            )
            done_mask = dones.view(-1, 1) > 0
            self.noise_scales = torch.where(
                done_mask, candidate_scales, self.noise_scales
            )

        act = self(obs)
        if not deterministic:
            perturbation = torch.randn_like(act) * self.noise_scales
            return act + perturbation
        return act
|
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -472,6 +474,44 @@ class EmpiricalNormalization(nn.Module):
|
|||||||
return y * (self._std + self.eps) + self._mean
|
return y * (self._std + self.eps) + self._mean
|
||||||
|
|
||||||
|
|
||||||
|
class RewardNormalizer(nn.Module):
    """Scale rewards by a running estimate of the discounted-return magnitude.

    Buffers:
        G: Running discounted return. It starts as a single zero and takes the
            broadcast shape of the `rewards` / `dones` passed to `update_stats`.
        G_r_max: Largest absolute value of `G` seen so far, shape `(1,)`.

    Args:
        gamma: Discount factor used to accumulate `G`.
        device: Device on which the buffers are allocated.
        g_max: Rewards are scaled so that `G_r_max / denominator <= g_max`.
        epsilon: Added to the running standard deviation to avoid division by zero.
    """

    def __init__(
        self,
        gamma: float,
        device: torch.device,
        g_max: float = 10.0,
        epsilon: float = 1e-8,
    ):
        super().__init__()
        # Running estimate of the discounted return.
        self.register_buffer("G", torch.zeros(1, device=device))
        # Running maximum of |G|.
        self.register_buffer("G_r_max", torch.zeros(1, device=device))
        # Running statistics of G; its `std` is used as the default denominator.
        self.G_rms = EmpiricalNormalization(shape=1, device=device)
        self.gamma = gamma
        self.g_max = g_max
        self.epsilon = epsilon

    def _scale_reward(self, rewards: torch.Tensor) -> torch.Tensor:
        """Divide `rewards` by max(std(G) + epsilon, G_r_max / g_max)."""
        var_denominator = self.G_rms.std[0] + self.epsilon
        # Lower bound on the denominator so the largest observed |G| maps to
        # at most g_max after scaling.
        min_required_denominator = self.G_r_max / self.g_max
        denominator = torch.maximum(var_denominator, min_required_denominator)

        return rewards / denominator

    def update_stats(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
    ):
        """Advance the running return with one step of `rewards` and `dones`.

        `G` is reset (multiplied by zero) wherever `dones` is 1 before the new
        reward is added; the running statistics and running maximum of `G` are
        then updated.
        """
        self.G = self.gamma * (1 - dones) * self.G + rewards
        self.G_rms.update(self.G.view(-1, 1))
        # Tensor reduction instead of Python's builtin max(): no per-element
        # Python iteration, and the buffer keeps its registered (1,) shape.
        self.G_r_max = torch.maximum(self.G_r_max, self.G.abs().max())

    def forward(self, rewards: torch.Tensor) -> torch.Tensor:
        """Return `rewards` scaled by the current normalisation denominator."""
        return self._scale_reward(rewards)
|
||||||
|
|
||||||
|
|
||||||
def cpu_state(sd):
|
def cpu_state(sd):
|
||||||
# detach & move to host without locking the compute stream
|
# detach & move to host without locking the compute stream
|
||||||
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
|
return {k: v.detach().to("cpu", non_blocking=True) for k, v in sd.items()}
|
||||||
|
@ -10,6 +10,8 @@ class BaseArgs:
|
|||||||
# See IsaacLabArgs for default hyperparameters for IsaacLab
|
# See IsaacLabArgs for default hyperparameters for IsaacLab
|
||||||
env_name: str = "h1hand-stand-v0"
|
env_name: str = "h1hand-stand-v0"
|
||||||
"""the id of the environment"""
|
"""the id of the environment"""
|
||||||
|
agent: str = "fasttd3"
|
||||||
|
"""the agent to use: currently support [fasttd3, fasttd3_simbav2]"""
|
||||||
seed: int = 1
|
seed: int = 1
|
||||||
"""seed of the experiment"""
|
"""seed of the experiment"""
|
||||||
torch_deterministic: bool = True
|
torch_deterministic: bool = True
|
||||||
@ -36,6 +38,10 @@ class BaseArgs:
|
|||||||
"""the learning rate of the critic"""
|
"""the learning rate of the critic"""
|
||||||
actor_learning_rate: float = 3e-4
|
actor_learning_rate: float = 3e-4
|
||||||
"""the learning rate for the actor"""
|
"""the learning rate for the actor"""
|
||||||
|
critic_learning_rate_end: float = 3e-4
|
||||||
|
"""the learning rate of the critic at the end of training"""
|
||||||
|
actor_learning_rate_end: float = 3e-4
|
||||||
|
"""the learning rate for the actor at the end of training"""
|
||||||
buffer_size: int = 1024 * 50
|
buffer_size: int = 1024 * 50
|
||||||
"""the replay memory buffer size"""
|
"""the replay memory buffer size"""
|
||||||
num_steps: int = 1
|
num_steps: int = 1
|
||||||
@ -72,6 +78,10 @@ class BaseArgs:
|
|||||||
"""the hidden dimension of the critic network"""
|
"""the hidden dimension of the critic network"""
|
||||||
actor_hidden_dim: int = 512
|
actor_hidden_dim: int = 512
|
||||||
"""the hidden dimension of the actor network"""
|
"""the hidden dimension of the actor network"""
|
||||||
|
critic_num_blocks: int = 2
|
||||||
|
"""(SimbaV2 only) the number of blocks in the critic network"""
|
||||||
|
actor_num_blocks: int = 1
|
||||||
|
"""(SimbaV2 only) the number of blocks in the actor network"""
|
||||||
use_cdq: bool = True
|
use_cdq: bool = True
|
||||||
"""whether to use Clipped Double Q-learning"""
|
"""whether to use Clipped Double Q-learning"""
|
||||||
measure_burnin: int = 3
|
measure_burnin: int = 3
|
||||||
@ -84,6 +94,8 @@ class BaseArgs:
|
|||||||
"""whether to use torch.compile."""
|
"""whether to use torch.compile."""
|
||||||
obs_normalization: bool = True
|
obs_normalization: bool = True
|
||||||
"""whether to enable observation normalization"""
|
"""whether to enable observation normalization"""
|
||||||
|
reward_normalization: bool = False
|
||||||
|
"""whether to enable reward normalization (Not recommended for now, it's unstable.)"""
|
||||||
max_grad_norm: float = 0.0
|
max_grad_norm: float = 0.0
|
||||||
"""the maximum gradient norm"""
|
"""the maximum gradient norm"""
|
||||||
amp: bool = True
|
amp: bool = True
|
||||||
@ -350,6 +362,7 @@ class Go1GetupArgs(MuJoCoPlaygroundArgs):
|
|||||||
class LeapCubeReorientArgs(MuJoCoPlaygroundArgs):
|
class LeapCubeReorientArgs(MuJoCoPlaygroundArgs):
|
||||||
env_name: str = "LeapCubeReorient"
|
env_name: str = "LeapCubeReorient"
|
||||||
num_steps: int = 3
|
num_steps: int = 3
|
||||||
|
gamma: float = 0.99
|
||||||
policy_noise: float = 0.2
|
policy_noise: float = 0.2
|
||||||
v_min: float = -50.0
|
v_min: float = -50.0
|
||||||
v_max: float = 50.0
|
v_max: float = 50.0
|
||||||
@ -361,6 +374,7 @@ class LeapCubeRotateZAxisArgs(MuJoCoPlaygroundArgs):
|
|||||||
env_name: str = "LeapCubeRotateZAxis"
|
env_name: str = "LeapCubeRotateZAxis"
|
||||||
num_steps: int = 1
|
num_steps: int = 1
|
||||||
policy_noise: float = 0.2
|
policy_noise: float = 0.2
|
||||||
|
gamma: float = 0.99
|
||||||
v_min: float = -10.0
|
v_min: float = -10.0
|
||||||
v_max: float = 10.0
|
v_max: float = 10.0
|
||||||
use_cdq: bool = False
|
use_cdq: bool = False
|
||||||
|
@ -12,6 +12,7 @@ os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest"
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import math
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
import wandb
|
import wandb
|
||||||
@ -25,9 +26,13 @@ from torch.amp import autocast, GradScaler
|
|||||||
|
|
||||||
from tensordict import TensorDict, from_module
|
from tensordict import TensorDict, from_module
|
||||||
|
|
||||||
from fast_td3_utils import EmpiricalNormalization, SimpleReplayBuffer, save_params
|
from fast_td3_utils import (
|
||||||
|
EmpiricalNormalization,
|
||||||
|
RewardNormalizer,
|
||||||
|
SimpleReplayBuffer,
|
||||||
|
save_params,
|
||||||
|
)
|
||||||
from hyperparams import get_args
|
from hyperparams import get_args
|
||||||
from fast_td3 import Actor, Critic
|
|
||||||
|
|
||||||
torch.set_float32_matmul_precision("high")
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
@ -135,44 +140,74 @@ def main():
|
|||||||
obs_normalizer = nn.Identity()
|
obs_normalizer = nn.Identity()
|
||||||
critic_obs_normalizer = nn.Identity()
|
critic_obs_normalizer = nn.Identity()
|
||||||
|
|
||||||
actor = Actor(
|
if args.reward_normalization:
|
||||||
n_obs=n_obs,
|
reward_normalizer = RewardNormalizer(
|
||||||
n_act=n_act,
|
gamma=args.gamma, device=device, g_max=min(abs(args.v_min), abs(args.v_max))
|
||||||
num_envs=args.num_envs,
|
)
|
||||||
device=device,
|
else:
|
||||||
init_scale=args.init_scale,
|
reward_normalizer = nn.Identity()
|
||||||
hidden_dim=args.actor_hidden_dim,
|
|
||||||
)
|
actor_kwargs = {
|
||||||
actor_detach = Actor(
|
"n_obs": n_obs,
|
||||||
n_obs=n_obs,
|
"n_act": n_act,
|
||||||
n_act=n_act,
|
"num_envs": args.num_envs,
|
||||||
num_envs=args.num_envs,
|
"device": device,
|
||||||
device=device,
|
"init_scale": args.init_scale,
|
||||||
init_scale=args.init_scale,
|
"hidden_dim": args.actor_hidden_dim,
|
||||||
hidden_dim=args.actor_hidden_dim,
|
}
|
||||||
)
|
critic_kwargs = {
|
||||||
|
"n_obs": n_critic_obs,
|
||||||
|
"n_act": n_act,
|
||||||
|
"num_atoms": args.num_atoms,
|
||||||
|
"v_min": args.v_min,
|
||||||
|
"v_max": args.v_max,
|
||||||
|
"hidden_dim": args.critic_hidden_dim,
|
||||||
|
"device": device,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.agent == "fasttd3":
|
||||||
|
from fast_td3 import Actor, Critic
|
||||||
|
|
||||||
|
print("Using FastTD3")
|
||||||
|
elif args.agent == "fasttd3_simbav2":
|
||||||
|
from fast_td3_simbav2 import Actor, Critic
|
||||||
|
|
||||||
|
print("Using FastTD3 + SimbaV2")
|
||||||
|
actor_kwargs.pop("init_scale")
|
||||||
|
actor_kwargs.update(
|
||||||
|
{
|
||||||
|
"scaler_init": math.sqrt(2.0 / args.actor_hidden_dim),
|
||||||
|
"scaler_scale": math.sqrt(2.0 / args.actor_hidden_dim),
|
||||||
|
"alpha_init": 1.0 / (args.actor_num_blocks + 1),
|
||||||
|
"alpha_scale": 1.0 / math.sqrt(args.actor_hidden_dim),
|
||||||
|
"expansion": 4,
|
||||||
|
"c_shift": 3.0,
|
||||||
|
"num_blocks": args.actor_num_blocks,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
critic_kwargs.update(
|
||||||
|
{
|
||||||
|
"scaler_init": math.sqrt(2.0 / args.critic_hidden_dim),
|
||||||
|
"scaler_scale": math.sqrt(2.0 / args.critic_hidden_dim),
|
||||||
|
"alpha_init": 1.0 / (args.critic_num_blocks + 1),
|
||||||
|
"alpha_scale": 1.0 / math.sqrt(args.critic_hidden_dim),
|
||||||
|
"num_blocks": args.critic_num_blocks,
|
||||||
|
"expansion": 4,
|
||||||
|
"c_shift": 3.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Agent {args.agent} not supported")
|
||||||
|
|
||||||
|
actor = Actor(**actor_kwargs)
|
||||||
|
actor_detach = Actor(**actor_kwargs)
|
||||||
|
|
||||||
# Copy params to actor_detach without grad
|
# Copy params to actor_detach without grad
|
||||||
from_module(actor).data.to_module(actor_detach)
|
from_module(actor).data.to_module(actor_detach)
|
||||||
policy = actor_detach.explore
|
policy = actor_detach.explore
|
||||||
|
|
||||||
qnet = Critic(
|
qnet = Critic(**critic_kwargs)
|
||||||
n_obs=n_critic_obs,
|
qnet_target = Critic(**critic_kwargs)
|
||||||
n_act=n_act,
|
|
||||||
num_atoms=args.num_atoms,
|
|
||||||
v_min=args.v_min,
|
|
||||||
v_max=args.v_max,
|
|
||||||
hidden_dim=args.critic_hidden_dim,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
qnet_target = Critic(
|
|
||||||
n_obs=n_critic_obs,
|
|
||||||
n_act=n_act,
|
|
||||||
num_atoms=args.num_atoms,
|
|
||||||
v_min=args.v_min,
|
|
||||||
v_max=args.v_max,
|
|
||||||
hidden_dim=args.critic_hidden_dim,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
qnet_target.load_state_dict(qnet.state_dict())
|
qnet_target.load_state_dict(qnet.state_dict())
|
||||||
|
|
||||||
q_optimizer = optim.AdamW(
|
q_optimizer = optim.AdamW(
|
||||||
@ -186,6 +221,18 @@ def main():
|
|||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add learning rate schedulers
|
||||||
|
q_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
q_optimizer,
|
||||||
|
T_max=args.total_timesteps,
|
||||||
|
eta_min=args.critic_learning_rate_end,  # Anneal down to the configured final critic lr (absolute value, not a fixed fraction of the initial lr)
|
||||||
|
)
|
||||||
|
actor_scheduler = optim.lr_scheduler.CosineAnnealingLR(
|
||||||
|
actor_optimizer,
|
||||||
|
T_max=args.total_timesteps,
|
||||||
|
eta_min=args.actor_learning_rate_end,  # Anneal down to the configured final actor lr (absolute value, not a fixed fraction of the initial lr)
|
||||||
|
)
|
||||||
|
|
||||||
rb = SimpleReplayBuffer(
|
rb = SimpleReplayBuffer(
|
||||||
n_env=args.num_envs,
|
n_env=args.num_envs,
|
||||||
buffer_size=args.buffer_size,
|
buffer_size=args.buffer_size,
|
||||||
@ -353,7 +400,6 @@ def main():
|
|||||||
scaler.step(q_optimizer)
|
scaler.step(q_optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
|
|
||||||
logs_dict["buffer_rewards"] = rewards.mean()
|
|
||||||
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
|
logs_dict["critic_grad_norm"] = critic_grad_norm.detach()
|
||||||
logs_dict["qf_loss"] = qf_loss.detach()
|
logs_dict["qf_loss"] = qf_loss.detach()
|
||||||
logs_dict["qf_max"] = qf1_next_target_value.max().detach()
|
logs_dict["qf_max"] = qf1_next_target_value.max().detach()
|
||||||
@ -399,9 +445,15 @@ def main():
|
|||||||
policy = torch.compile(policy, mode=mode)
|
policy = torch.compile(policy, mode=mode)
|
||||||
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
|
normalize_obs = torch.compile(obs_normalizer.forward, mode=mode)
|
||||||
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
|
normalize_critic_obs = torch.compile(critic_obs_normalizer.forward, mode=mode)
|
||||||
|
if args.reward_normalization:
|
||||||
|
update_stats = torch.compile(reward_normalizer.update_stats, mode=mode)
|
||||||
|
normalize_reward = torch.compile(reward_normalizer.forward, mode=mode)
|
||||||
else:
|
else:
|
||||||
normalize_obs = obs_normalizer.forward
|
normalize_obs = obs_normalizer.forward
|
||||||
normalize_critic_obs = critic_obs_normalizer.forward
|
normalize_critic_obs = critic_obs_normalizer.forward
|
||||||
|
if args.reward_normalization:
|
||||||
|
update_stats = reward_normalizer.update_stats
|
||||||
|
normalize_reward = reward_normalizer.forward
|
||||||
|
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
obs, critic_obs = envs.reset_with_critic_obs()
|
obs, critic_obs = envs.reset_with_critic_obs()
|
||||||
@ -447,6 +499,9 @@ def main():
|
|||||||
next_obs, rewards, dones, infos = envs.step(actions.float())
|
next_obs, rewards, dones, infos = envs.step(actions.float())
|
||||||
truncations = infos["time_outs"]
|
truncations = infos["time_outs"]
|
||||||
|
|
||||||
|
if args.reward_normalization:
|
||||||
|
update_stats(rewards, dones.float())
|
||||||
|
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
next_critic_obs = infos["observations"]["critic"]
|
next_critic_obs = infos["observations"]["critic"]
|
||||||
|
|
||||||
@ -494,6 +549,8 @@ def main():
|
|||||||
data["next"]["observations"] = normalize_obs(
|
data["next"]["observations"] = normalize_obs(
|
||||||
data["next"]["observations"]
|
data["next"]["observations"]
|
||||||
)
|
)
|
||||||
|
raw_rewards = data["next"]["rewards"]
|
||||||
|
data["next"]["rewards"] = normalize_reward(raw_rewards)
|
||||||
if envs.asymmetric_obs:
|
if envs.asymmetric_obs:
|
||||||
data["critic_observations"] = normalize_critic_obs(
|
data["critic_observations"] = normalize_critic_obs(
|
||||||
data["critic_observations"]
|
data["critic_observations"]
|
||||||
@ -527,8 +584,8 @@ def main():
|
|||||||
"qf_min": logs_dict["qf_min"].mean(),
|
"qf_min": logs_dict["qf_min"].mean(),
|
||||||
"actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
|
"actor_grad_norm": logs_dict["actor_grad_norm"].mean(),
|
||||||
"critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
|
"critic_grad_norm": logs_dict["critic_grad_norm"].mean(),
|
||||||
"buffer_rewards": logs_dict["buffer_rewards"].mean(),
|
|
||||||
"env_rewards": rewards.mean(),
|
"env_rewards": rewards.mean(),
|
||||||
|
"buffer_rewards": raw_rewards.mean(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
|
if args.eval_interval > 0 and global_step % args.eval_interval == 0:
|
||||||
@ -563,6 +620,8 @@ def main():
|
|||||||
{
|
{
|
||||||
"speed": speed,
|
"speed": speed,
|
||||||
"frame": global_step * args.num_envs,
|
"frame": global_step * args.num_envs,
|
||||||
|
"critic_lr": q_scheduler.get_last_lr()[0],
|
||||||
|
"actor_lr": actor_scheduler.get_last_lr()[0],
|
||||||
**logs,
|
**logs,
|
||||||
},
|
},
|
||||||
step=global_step,
|
step=global_step,
|
||||||
@ -585,6 +644,10 @@ def main():
|
|||||||
f"models/{run_name}_{global_step}.pt",
|
f"models/{run_name}_{global_step}.pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update learning rates
|
||||||
|
q_scheduler.step()
|
||||||
|
actor_scheduler.step()
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user