Simplify operations on spaces (is_discrete, shape)

This commit is contained in:
Dominik Moritz Roth 2024-11-07 11:40:32 +01:00
parent 5c44448e53
commit 4f8fc500b7
2 changed files with 13 additions and 71 deletions

View File

@ -1,14 +1,17 @@
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict
class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
self.discrete = is_discrete_space(act_space)
act_space_shape = get_space_shape(act_space)
self.discrete = isinstance(act_space, DiscreteTensorSpec)
obs_space = obs_space["observation"]
act_space_shape = act_space.shape[1:]
obs_space_shape = obs_space.shape[1:]
if self.discrete and full_covariance:
raise ValueError("Full covariance is not applicable for discrete action spaces.")
@ -16,18 +19,18 @@ class Actor(TensorDictModule):
self.full_covariance = full_covariance
if self.discrete:
out_features = act_space_shape[-1]
out_keys = ["action_logits"]
out_features = act_space_shape[0]
out_keys = ["logits"]
else:
if full_covariance:
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
out_keys = ["loc", "scale_tril"]
else:
out_features = act_space_shape[-1] * 2
out_features = act_space_shape[0] * 2
out_keys = ["loc", "scale"]
actor_module = MLP(
in_features=get_space_shape(obs_space)[-1],
in_features=obs_space_shape[0],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
@ -36,7 +39,7 @@ class Actor(TensorDictModule):
if not self.discrete:
if full_covariance:
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
else:
param_extractor = NormalParamExtractor()
actor_module = nn.Sequential(actor_module, param_extractor)
@ -63,7 +66,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP(
in_features=get_space_shape(obs_space)[-1],
in_features=obs_space.shape[-1],
out_features=1,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),

View File

@ -1,61 +0,0 @@
import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import (
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space):
discrete_types = (
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
)
continuous_types = (
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
return True
elif isinstance(action_space, continuous_types):
return False
else:
raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space):
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
return (action_space.n,)
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
return (sum(action_space.nvec),)
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif gym_available:
if isinstance(action_space, GymDiscrete):
return (action_space.n,)
elif isinstance(action_space, GymMultiDiscrete):
return (sum(action_space.nvec),)
elif isinstance(action_space, GymMultiBinary):
return (action_space.n,)
elif isinstance(action_space, continuous_types):
return action_space.shape
raise ValueError(f"Unsupported action space type: {type(action_space)}")