From e927afcc301e742094b9c198763e5ec2055831ca Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 21 Oct 2024 15:24:20 +0200 Subject: [PATCH] Extend policy impl --- fancy_rl/policy.py | 62 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 15 deletions(-) diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py index aa96b63..6209e60 100644 --- a/fancy_rl/policy.py +++ b/fancy_rl/policy.py @@ -3,31 +3,63 @@ from tensordict.nn import TensorDictModule from torchrl.modules import MLP from tensordict.nn.distributions import NormalParamExtractor from fancy_rl.utils import is_discrete_space, get_space_shape +from tensordict import TensorDict class Actor(TensorDictModule): - def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device): + def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False): + self.discrete = is_discrete_space(act_space) act_space_shape = get_space_shape(act_space) - if is_discrete_space(act_space): - out_features = act_space_shape[-1] - else: - out_features = act_space_shape[-1] * 2 + + if self.discrete and full_covariance: + raise ValueError("Full covariance is not applicable for discrete action spaces.") + + self.full_covariance = full_covariance - actor_module = nn.Sequential( - MLP( - in_features=get_space_shape(obs_space)[-1], - out_features=out_features, - num_cells=hidden_sizes, - activation_class=getattr(nn, activation_fn), - device=device - ), - NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(), + if self.discrete: + out_features = act_space_shape[-1] + out_keys = ["action_logits"] + else: + if full_covariance: + out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2 + out_keys = ["loc", "scale_tril"] + else: + out_features = act_space_shape[-1] * 2 + out_keys = ["loc", "scale"] + + actor_module = MLP( + in_features=get_space_shape(obs_space)[-1], + out_features=out_features, + num_cells=hidden_sizes, + activation_class=getattr(nn, activation_fn), + device=device ).to(device) + + if not self.discrete: + if full_covariance: + param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1]) + else: + param_extractor = NormalParamExtractor() + actor_module = nn.Sequential(actor_module, param_extractor) + super().__init__( module=actor_module, in_keys=["observation"], - out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"], + out_keys=out_keys ) +class FullCovarianceNormalParamExtractor(nn.Module): + def __init__(self, action_dim): + super().__init__() + self.action_dim = action_dim + + def forward(self, x): + loc = x[:, :self.action_dim] + scale_tril = torch.zeros(x.shape[0], self.action_dim, self.action_dim, device=x.device) + tril_indices = torch.tril_indices(row=self.action_dim, col=self.action_dim, offset=0) + scale_tril[:, tril_indices[0], tril_indices[1]] = x[:, self.action_dim:] + scale_tril.diagonal(dim1=-2, dim2=-1).exp_() + return TensorDict({"loc": loc, "scale_tril": scale_tril}, batch_size=x.shape[0]) + class Critic(TensorDictModule): def __init__(self, obs_space, hidden_sizes, activation_fn, device): critic_module = MLP(