FastTD3/fast_td3/fast_td3_deploy.py

import torch
import torch.nn as nn

from .fast_td3_utils import EmpiricalNormalization
from .fast_td3 import Actor


class Policy(nn.Module):
    """Deployment wrapper: a trained actor plus its observation normalizer, on CPU."""

    def __init__(
        self,
        n_obs: int,
        n_act: int,
        num_envs: int,
        init_scale: float,
        actor_hidden_dim: int,
    ):
        super().__init__()
        self.actor = Actor(
            n_obs=n_obs,
            n_act=n_act,
            num_envs=num_envs,
            device="cpu",
            init_scale=init_scale,
            hidden_dim=actor_hidden_dim,
        )
        self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")

        # Inference-only module: keep both submodules in eval mode.
        self.actor.eval()
        self.obs_normalizer.eval()

    @torch.no_grad()
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        norm_obs = self.obs_normalizer(obs)
        actions = self.actor(norm_obs)
        return actions

    @torch.no_grad()
    def act(self, obs: torch.Tensor) -> torch.distributions.Normal:
        # Wrap the deterministic actions in a near-degenerate Normal so callers
        # that expect a distribution interface still work.
        actions = self.forward(obs)
        return torch.distributions.Normal(actions, torch.ones_like(actions) * 1e-8)


def load_policy(checkpoint_path):
    torch_checkpoint = torch.load(
        f"{checkpoint_path}", map_location="cpu", weights_only=False
    )
    args = torch_checkpoint["args"]

    # Infer observation/action dimensions from the stored actor weights.
    n_obs = torch_checkpoint["actor_state_dict"]["net.0.weight"].shape[-1]
    n_act = torch_checkpoint["actor_state_dict"]["fc_mu.0.weight"].shape[0]

    policy = Policy(
        n_obs=n_obs,
        n_act=n_act,
        num_envs=args["num_envs"],
        init_scale=args["init_scale"],
        actor_hidden_dim=args["actor_hidden_dim"],
    )
    policy.actor.load_state_dict(torch_checkpoint["actor_state_dict"])

    # An empty normalizer state means normalization was disabled during training.
    if len(torch_checkpoint["obs_normalizer_state"]) == 0:
        policy.obs_normalizer = nn.Identity()
    else:
        policy.obs_normalizer.load_state_dict(torch_checkpoint["obs_normalizer_state"])

    return policy
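

if __name__ == "__main__":
    # Usage sketch (not part of the original module): load a trained checkpoint
    # and query actions for a zero observation batch. Run as
    # `python -m fast_td3.fast_td3_deploy <checkpoint>` so the relative imports
    # above resolve; the default checkpoint filename is a hypothetical placeholder.
    import sys

    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else "checkpoint.pt"
    policy = load_policy(ckpt_path)

    # Infer n_obs from the actor's first layer, consistent with the
    # "net.0.weight" key that load_policy reads from the state dict.
    n_obs = policy.actor.net[0].weight.shape[-1]
    obs = torch.zeros(1, n_obs)

    actions = policy(obs)   # deterministic actions from the actor
    dist = policy.act(obs)  # Normal with near-zero std, for distribution-based callers
    print("actions:", tuple(actions.shape), "dist mean:", tuple(dist.mean.shape))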