"""
|
|
Gaussian policy parameterization.
|
|
|
|
"""
|
|
|
|
import torch
|
|
import torch.distributions as D
|
|
import logging
|
|
|
|
log = logging.getLogger(__name__)


class GaussianModel(torch.nn.Module):

    def __init__(
        self,
        network,
        horizon_steps,
        network_path=None,
        device="cuda:0",
        randn_clip_value=10,
        tanh_output=False,
    ):
        super().__init__()
        self.device = device
        self.network = network.to(device)
        if network_path is not None:
            checkpoint = torch.load(
                network_path,
                map_location=self.device,
                weights_only=True,
            )
            self.load_state_dict(
                checkpoint["model"],
                strict=False,
            )
            log.info("Loaded actor from %s", network_path)
        log.info(
            f"Number of network parameters: {sum(p.numel() for p in self.parameters())}"
        )
        self.horizon_steps = horizon_steps

        # Clip the sampled noise (in units of the standard deviation) so the
        # sampled action does not stray too far from the mean
        self.randn_clip_value = randn_clip_value

        # Whether to apply tanh to the *sampled* action (used in SAC)
        self.tanh_output = tanh_output

    def loss(
        self,
        true_action,
        cond,
        ent_coef,
    ):
        """Entropy-regularized negative log-likelihood loss (no tanh squashing)."""
        B = len(true_action)
        dist = self.forward_train(
            cond,
            deterministic=False,
        )
        true_action = true_action.view(B, -1)
        loss = -dist.log_prob(true_action)  # [B, T * action_dim]
        entropy = dist.entropy().mean()
        loss = loss.mean() - entropy * ent_coef
        return loss, {"entropy": entropy}

    def forward_train(
        self,
        cond,
        deterministic=False,
        network_override=None,
    ):
        """
        Calls the network to compute the mean and scale of the Gaussian.
        Returns a torch.distributions.Normal object.
        """
        if network_override is not None:
            means, scales = network_override(cond)
        else:
            means, scales = self.network(cond)
        if deterministic:
            # Low noise for all Gaussian dimensions, so samples stay near the mean
            scales = torch.ones_like(means) * 1e-4
        return D.Normal(loc=means, scale=scales)
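
    # Evaluation sketch (illustrative): with deterministic=True the scales are
    # fixed at ~1e-4, so samples from the returned Normal are effectively the
    # predicted means.
    #
    #     with torch.no_grad():
    #         eval_action = model(cond, deterministic=True)  # ~mean actions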

    def forward(
        self,
        cond,
        deterministic=False,
        network_override=None,
        reparameterize=False,
        get_logprob=False,
    ):
        B = len(cond["state"]) if "state" in cond else len(cond["rgb"])
        T = self.horizon_steps
        dist = self.forward_train(
            cond,
            deterministic=deterministic,
            network_override=network_override,
        )
        if reparameterize:
            sampled_action = dist.rsample()
        else:
            sampled_action = dist.sample()
        sampled_action.clamp_(
            dist.loc - self.randn_clip_value * dist.scale,
            dist.loc + self.randn_clip_value * dist.scale,
        )

        if get_logprob:
            log_prob = dist.log_prob(sampled_action)

            # For SAC/RLPD, squash the sampled action here instead of
            # squashing right after the network output as in PPO
            if self.tanh_output:
                sampled_action = torch.tanh(sampled_action)
                log_prob -= torch.log(1 - sampled_action.pow(2) + 1e-6)
            return sampled_action.view(B, T, -1), log_prob.sum(1, keepdim=False)
        else:
            if self.tanh_output:
                sampled_action = torch.tanh(sampled_action)
            return sampled_action.view(B, T, -1)
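

if __name__ == "__main__":
    # Smoke-test sketch (illustrative, not part of the original file). It
    # assumes a toy network mapping cond["state"] of shape [B, obs_dim] to
    # flattened means and scales of shape [B, horizon_steps * action_dim];
    # the surrounding repo supplies its own MLP/vision networks for this role.
    class _ToyGaussianMLP(torch.nn.Module):
        def __init__(self, obs_dim, horizon_steps, action_dim):
            super().__init__()
            self.mean_head = torch.nn.Linear(obs_dim, horizon_steps * action_dim)
            self.log_scale = torch.nn.Parameter(
                torch.zeros(horizon_steps * action_dim)
            )

        def forward(self, cond):
            means = self.mean_head(cond["state"])
            # Positive scales, broadcast to the batch dimension
            scales = self.log_scale.exp().expand_as(means)
            return means, scales

    obs_dim, action_dim, horizon_steps, batch_size = 11, 3, 4, 8
    model = GaussianModel(
        network=_ToyGaussianMLP(obs_dim, horizon_steps, action_dim),
        horizon_steps=horizon_steps,
        device="cpu",
    )
    cond = {"state": torch.randn(batch_size, obs_dim)}
    actions = model(cond)  # [B, horizon_steps, action_dim]
    true_actions = torch.randn(batch_size, horizon_steps, action_dim)
    loss, info = model.loss(true_actions, cond, ent_coef=0.01)
    print(actions.shape, loss.item(), info["entropy"].item())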