From 50733bb1a492db39ae16ed2062a4b51d1429fa67 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 2 Jun 2024 11:07:57 +0200 Subject: [PATCH] Make use of new utils to work with any space --- fancy_rl/policy.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py index 9146eb0..51d708b 100644 --- a/fancy_rl/policy.py +++ b/fancy_rl/policy.py @@ -2,21 +2,22 @@ import torch.nn as nn from tensordict.nn import TensorDictModule from torchrl.modules import MLP from tensordict.nn.distributions import NormalParamExtractor +from fancy_rl.utils import is_discrete_space, get_space_shape class SharedModule(TensorDictModule): def __init__(self, obs_space, hidden_sizes, activation_fn, device): if hidden_sizes: shared_module = MLP( - in_features=obs_space.shape[-1], + in_features=get_space_shape(obs_space)[-1], out_features=hidden_sizes[-1], - num_cells=hidden_sizes, + num_cells=hidden_sizes[:-1], activation_class=getattr(nn, activation_fn), device=device ) out_features = hidden_sizes[-1] else: shared_module = nn.Identity() - out_features = obs_space.shape[-1] + out_features = get_space_shape(obs_space)[-1] super().__init__( module=shared_module, @@ -27,20 +28,26 @@ class SharedModule(TensorDictModule): class Actor(TensorDictModule): def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device): + act_space_shape = get_space_shape(act_space) + if is_discrete_space(act_space): + out_features = act_space_shape[-1] + else: + out_features = act_space_shape[-1] * 2 + actor_module = nn.Sequential( MLP( in_features=shared_module.out_features, - out_features=act_space.shape[-1] * 2, + out_features=out_features, num_cells=hidden_sizes, activation_class=getattr(nn, activation_fn), device=device ), - NormalParamExtractor(), + NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(), ).to(device) super().__init__( module=actor_module, in_keys=["shared"], - out_keys=["loc", "scale"], + out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"], ) class Critic(TensorDictModule):