Make use of new utils to work with any space

This commit is contained in:
Dominik Moritz Roth 2024-06-02 11:07:57 +02:00
parent 59060c7533
commit 50733bb1a4

View File

@ -2,21 +2,22 @@ import torch.nn as nn
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from torchrl.modules import MLP from torchrl.modules import MLP
from tensordict.nn.distributions import NormalParamExtractor from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
class SharedModule(TensorDictModule): class SharedModule(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device): def __init__(self, obs_space, hidden_sizes, activation_fn, device):
if hidden_sizes: if hidden_sizes:
shared_module = MLP( shared_module = MLP(
in_features=obs_space.shape[-1], in_features=get_space_shape(obs_space)[-1],
out_features=hidden_sizes[-1], out_features=hidden_sizes[-1],
num_cells=hidden_sizes, num_cells=hidden_sizes[:-1],
activation_class=getattr(nn, activation_fn), activation_class=getattr(nn, activation_fn),
device=device device=device
) )
out_features = hidden_sizes[-1] out_features = hidden_sizes[-1]
else: else:
shared_module = nn.Identity() shared_module = nn.Identity()
out_features = obs_space.shape[-1] out_features = get_space_shape(obs_space)[-1]
super().__init__( super().__init__(
module=shared_module, module=shared_module,
@ -27,20 +28,26 @@ class SharedModule(TensorDictModule):
class Actor(TensorDictModule): class Actor(TensorDictModule):
def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device): def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
act_space_shape = get_space_shape(act_space)
if is_discrete_space(act_space):
out_features = act_space_shape[-1]
else:
out_features = act_space_shape[-1] * 2
actor_module = nn.Sequential( actor_module = nn.Sequential(
MLP( MLP(
in_features=shared_module.out_features, in_features=shared_module.out_features,
out_features=act_space.shape[-1] * 2, out_features=out_features,
num_cells=hidden_sizes, num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn), activation_class=getattr(nn, activation_fn),
device=device device=device
), ),
NormalParamExtractor(), NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
).to(device) ).to(device)
super().__init__( super().__init__(
module=actor_module, module=actor_module,
in_keys=["shared"], in_keys=["shared"],
out_keys=["loc", "scale"], out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
) )
class Critic(TensorDictModule): class Critic(TensorDictModule):