Simplify operations on spaces (is_discrete, shape)
This commit is contained in:
parent
5c44448e53
commit
4f8fc500b7
@ -1,14 +1,17 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from torchrl.modules import MLP
|
from torchrl.modules import MLP
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from tensordict.nn.distributions import NormalParamExtractor
|
from tensordict.nn.distributions import NormalParamExtractor
|
||||||
from fancy_rl.utils import is_discrete_space, get_space_shape
|
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
class Actor(TensorDictModule):
|
class Actor(TensorDictModule):
|
||||||
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
|
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
|
||||||
self.discrete = is_discrete_space(act_space)
|
self.discrete = isinstance(act_space, DiscreteTensorSpec)
|
||||||
act_space_shape = get_space_shape(act_space)
|
|
||||||
|
obs_space = obs_space["observation"]
|
||||||
|
act_space_shape = act_space.shape[1:]
|
||||||
|
obs_space_shape = obs_space.shape[1:]
|
||||||
|
|
||||||
if self.discrete and full_covariance:
|
if self.discrete and full_covariance:
|
||||||
raise ValueError("Full covariance is not applicable for discrete action spaces.")
|
raise ValueError("Full covariance is not applicable for discrete action spaces.")
|
||||||
@ -16,18 +19,18 @@ class Actor(TensorDictModule):
|
|||||||
self.full_covariance = full_covariance
|
self.full_covariance = full_covariance
|
||||||
|
|
||||||
if self.discrete:
|
if self.discrete:
|
||||||
out_features = act_space_shape[-1]
|
out_features = act_space_shape[0]
|
||||||
out_keys = ["action_logits"]
|
out_keys = ["logits"]
|
||||||
else:
|
else:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
|
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
|
||||||
out_keys = ["loc", "scale_tril"]
|
out_keys = ["loc", "scale_tril"]
|
||||||
else:
|
else:
|
||||||
out_features = act_space_shape[-1] * 2
|
out_features = act_space_shape[0] * 2
|
||||||
out_keys = ["loc", "scale"]
|
out_keys = ["loc", "scale"]
|
||||||
|
|
||||||
actor_module = MLP(
|
actor_module = MLP(
|
||||||
in_features=get_space_shape(obs_space)[-1],
|
in_features=obs_space_shape[0],
|
||||||
out_features=out_features,
|
out_features=out_features,
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes,
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
@ -36,7 +39,7 @@ class Actor(TensorDictModule):
|
|||||||
|
|
||||||
if not self.discrete:
|
if not self.discrete:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
|
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
|
||||||
else:
|
else:
|
||||||
param_extractor = NormalParamExtractor()
|
param_extractor = NormalParamExtractor()
|
||||||
actor_module = nn.Sequential(actor_module, param_extractor)
|
actor_module = nn.Sequential(actor_module, param_extractor)
|
||||||
@ -63,7 +66,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
|
|||||||
class Critic(TensorDictModule):
|
class Critic(TensorDictModule):
|
||||||
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
||||||
critic_module = MLP(
|
critic_module = MLP(
|
||||||
in_features=get_space_shape(obs_space)[-1],
|
in_features=obs_space.shape[-1],
|
||||||
out_features=1,
|
out_features=1,
|
||||||
num_cells=hidden_sizes,
|
num_cells=hidden_sizes,
|
||||||
activation_class=getattr(nn, activation_fn),
|
activation_class=getattr(nn, activation_fn),
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
import gymnasium
|
|
||||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
|
||||||
from torchrl.data.tensor_specs import (
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
|
|
||||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import gym
|
|
||||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
|
||||||
gym_available = True
|
|
||||||
except ImportError:
|
|
||||||
gym_available = False
|
|
||||||
|
|
||||||
def is_discrete_space(action_space):
|
|
||||||
discrete_types = (
|
|
||||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
|
||||||
)
|
|
||||||
continuous_types = (
|
|
||||||
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
|
||||||
)
|
|
||||||
|
|
||||||
if gym_available:
|
|
||||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
|
||||||
continuous_types += (GymBox,)
|
|
||||||
|
|
||||||
if isinstance(action_space, discrete_types):
|
|
||||||
return True
|
|
||||||
elif isinstance(action_space, continuous_types):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
|
||||||
|
|
||||||
def get_space_shape(action_space):
|
|
||||||
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
|
||||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
|
|
||||||
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
|
|
||||||
|
|
||||||
if gym_available:
|
|
||||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
|
||||||
continuous_types += (GymBox,)
|
|
||||||
|
|
||||||
if isinstance(action_space, discrete_types):
|
|
||||||
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
|
|
||||||
return (sum(action_space.nvec),)
|
|
||||||
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif gym_available:
|
|
||||||
if isinstance(action_space, GymDiscrete):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, GymMultiDiscrete):
|
|
||||||
return (sum(action_space.nvec),)
|
|
||||||
elif isinstance(action_space, GymMultiBinary):
|
|
||||||
return (action_space.n,)
|
|
||||||
elif isinstance(action_space, continuous_types):
|
|
||||||
return action_space.shape
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
|
Loading…
Reference in New Issue
Block a user