67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from .fast_td3_utils import EmpiricalNormalization
|
|
from .fast_td3 import Actor
|
|
|
|
|
|
class Policy(nn.Module):
    """Inference-only wrapper pairing an ``Actor`` with an observation normalizer.

    Both sub-modules are created on the CPU and switched to eval mode in the
    constructor. ``forward`` and ``act`` run with gradient tracking disabled.
    """

    def __init__(
        self,
        n_obs: int,
        n_act: int,
        num_envs: int,
        init_scale: float,
        actor_hidden_dim: int,
    ):
        """Build the actor and the observation normalizer.

        Args:
            n_obs: Forwarded to ``Actor`` and used as the normalizer ``shape``.
            n_act: Forwarded to ``Actor``.
            num_envs: Forwarded to ``Actor``.
            init_scale: Forwarded to ``Actor``.
            actor_hidden_dim: Passed to ``Actor`` as ``hidden_dim``.
        """
        super().__init__()
        self.actor = Actor(
            n_obs=n_obs,
            n_act=n_act,
            num_envs=num_envs,
            device="cpu",
            init_scale=init_scale,
            hidden_dim=actor_hidden_dim,
        )
        self.obs_normalizer = EmpiricalNormalization(shape=n_obs, device="cpu")

        # This wrapper is only used for inference, so freeze both sub-modules
        # in eval mode right away.
        self.actor.eval()
        self.obs_normalizer.eval()

    # The call form ``torch.no_grad()`` works as a decorator on every PyTorch
    # version; the bare ``@torch.no_grad`` form is rejected by older releases.
    @torch.no_grad()
    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Normalize ``obs`` and return the actor's output for it.

        Args:
            obs: Observation tensor fed to the normalizer.

        Returns:
            Whatever ``self.actor`` produces for the normalized observations.
        """
        normalized = self.obs_normalizer(obs)
        return self.actor(normalized)

    @torch.no_grad()
    def act(self, obs: torch.Tensor) -> torch.distributions.Normal:
        """Wrap the actor output in a ``Normal`` distribution.

        Args:
            obs: Observation tensor, handled exactly as in ``forward``.

        Returns:
            A ``Normal`` whose mean is ``forward(obs)`` and whose standard
            deviation is the constant ``1e-8`` for every element.
        """
        mean = self.forward(obs)
        # A tiny, strictly positive std keeps the distribution valid while
        # leaving samples practically equal to the mean.
        std = torch.full_like(mean, 1e-8)
        return torch.distributions.Normal(mean, std)
|
|
|
|
|
|
def load_policy(checkpoint_path) -> "Policy":
    """Rebuild a CPU ``Policy`` from a saved checkpoint.

    The checkpoint must contain ``"args"`` (with ``"num_envs"``,
    ``"init_scale"`` and ``"actor_hidden_dim"``) and ``"actor_state_dict"``.
    Observation and action sizes are read from the shapes of the
    ``"net.0.weight"`` and ``"fc_mu.0.weight"`` entries of the actor state.

    Args:
        checkpoint_path: Anything ``torch.load`` accepts: a path string, a
            path-like object, or a binary file-like object.

    Returns:
        A ``Policy`` with the actor weights loaded. If the checkpoint has no
        (or an empty) ``"obs_normalizer_state"``, the normalizer is replaced
        by ``nn.Identity``; otherwise that state is loaded into it.

    Raises:
        KeyError: If a required checkpoint entry is missing.
    """
    # SECURITY: ``weights_only=False`` makes ``torch.load`` unpickle arbitrary
    # Python objects, which can execute code. It is kept because the
    # checkpoint also stores the ``args`` entry; only load trusted files.
    checkpoint = torch.load(
        checkpoint_path, map_location="cpu", weights_only=False
    )
    args = checkpoint["args"]
    actor_state = checkpoint["actor_state_dict"]

    # Infer the network's input/output sizes from the saved weight shapes
    # instead of requiring them to be stored separately.
    n_obs = actor_state["net.0.weight"].shape[-1]
    n_act = actor_state["fc_mu.0.weight"].shape[0]

    policy = Policy(
        n_obs=n_obs,
        n_act=n_act,
        num_envs=args["num_envs"],
        init_scale=args["init_scale"],
        actor_hidden_dim=args["actor_hidden_dim"],
    )
    policy.actor.load_state_dict(actor_state)

    # A missing, ``None`` or empty normalizer state all mean "no observation
    # normalization": fall back to a pass-through module.
    normalizer_state = checkpoint.get("obs_normalizer_state")
    if not normalizer_state:
        policy.obs_normalizer = nn.Identity()
    else:
        policy.obs_normalizer.load_state_dict(normalizer_state)

    return policy
|