Make use of new utils to work with any space
This commit is contained in:
parent
59060c7533
commit
50733bb1a4
@ -2,21 +2,22 @@ import torch.nn as nn
|
|||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from torchrl.modules import MLP
|
from torchrl.modules import MLP
|
||||||
from tensordict.nn.distributions import NormalParamExtractor
|
from tensordict.nn.distributions import NormalParamExtractor
|
||||||
|
from fancy_rl.utils import is_discrete_space, get_space_shape
|
||||||
|
|
||||||
class SharedModule(TensorDictModule):
|
class SharedModule(TensorDictModule):
|
||||||
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
||||||
if hidden_sizes:
|
if hidden_sizes:
|
||||||
shared_module = MLP(
|
shared_module = MLP(
|
||||||
in_features=obs_space.shape[-1],
|
in_features=get_space_shape(obs_space)[-1],
|
||||||
out_features=hidden_sizes[-1],
|
out_features=hidden_sizes[-1],
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes[:-1],
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
out_features = hidden_sizes[-1]
|
out_features = hidden_sizes[-1]
|
||||||
else:
|
else:
|
||||||
shared_module = nn.Identity()
|
shared_module = nn.Identity()
|
||||||
out_features = obs_space.shape[-1]
|
out_features = get_space_shape(obs_space)[-1]
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
module=shared_module,
|
module=shared_module,
|
||||||
@ -27,20 +28,26 @@ class SharedModule(TensorDictModule):
|
|||||||
|
|
||||||
class Actor(TensorDictModule):
|
class Actor(TensorDictModule):
|
||||||
def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
|
def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
|
||||||
|
act_space_shape = get_space_shape(act_space)
|
||||||
|
if is_discrete_space(act_space):
|
||||||
|
out_features = act_space_shape[-1]
|
||||||
|
else:
|
||||||
|
out_features = act_space_shape[-1] * 2
|
||||||
|
|
||||||
actor_module = nn.Sequential(
|
actor_module = nn.Sequential(
|
||||||
MLP(
|
MLP(
|
||||||
in_features=shared_module.out_features,
|
in_features=shared_module.out_features,
|
||||||
out_features=act_space.shape[-1] * 2,
|
out_features=out_features,
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes,
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
device=device
|
device=device
|
||||||
),
|
),
|
||||||
NormalParamExtractor(),
|
NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
|
||||||
).to(device)
|
).to(device)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
module=actor_module,
|
module=actor_module,
|
||||||
in_keys=["shared"],
|
in_keys=["shared"],
|
||||||
out_keys=["loc", "scale"],
|
out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
|
||||||
)
|
)
|
||||||
|
|
||||||
class Critic(TensorDictModule):
|
class Critic(TensorDictModule):
|
||||||
|
Loading…
Reference in New Issue
Block a user