diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py index 6209e60..6ce1590 100644 --- a/fancy_rl/policy.py +++ b/fancy_rl/policy.py @@ -1,14 +1,17 @@ import torch.nn as nn from tensordict.nn import TensorDictModule from torchrl.modules import MLP +from torchrl.data.tensor_specs import DiscreteTensorSpec from tensordict.nn.distributions import NormalParamExtractor -from fancy_rl.utils import is_discrete_space, get_space_shape from tensordict import TensorDict class Actor(TensorDictModule): def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False): - self.discrete = is_discrete_space(act_space) - act_space_shape = get_space_shape(act_space) + self.discrete = isinstance(act_space, DiscreteTensorSpec) + + obs_space = obs_space["observation"] + act_space_shape = act_space.shape[1:] + obs_space_shape = obs_space.shape[1:] if self.discrete and full_covariance: raise ValueError("Full covariance is not applicable for discrete action spaces.") @@ -16,18 +19,18 @@ class Actor(TensorDictModule): self.full_covariance = full_covariance if self.discrete: - out_features = act_space_shape[-1] - out_keys = ["action_logits"] + out_features = act_space_shape[0] + out_keys = ["logits"] else: if full_covariance: - out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2 + out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2 out_keys = ["loc", "scale_tril"] else: - out_features = act_space_shape[-1] * 2 + out_features = act_space_shape[0] * 2 out_keys = ["loc", "scale"] actor_module = MLP( - in_features=get_space_shape(obs_space)[-1], + in_features=obs_space_shape[0], out_features=out_features, num_cells=hidden_sizes, activation_class=getattr(nn, activation_fn), @@ -36,7 +39,7 @@ class Actor(TensorDictModule): if not self.discrete: if full_covariance: - param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1]) + param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0]) else: param_extractor = NormalParamExtractor() actor_module = nn.Sequential(actor_module, param_extractor) @@ -63,7 +66,7 @@ class FullCovarianceNormalParamExtractor(nn.Module): class Critic(TensorDictModule): def __init__(self, obs_space, hidden_sizes, activation_fn, device): critic_module = MLP( - in_features=get_space_shape(obs_space)[-1], + in_features=obs_space.shape[-1], out_features=1, num_cells=hidden_sizes, activation_class=getattr(nn, activation_fn), diff --git a/fancy_rl/utils.py b/fancy_rl/utils.py index 0469aaa..e69de29 100644 --- a/fancy_rl/utils.py +++ b/fancy_rl/utils.py @@ -1,61 +0,0 @@ -import gymnasium -from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox -from torchrl.data.tensor_specs import ( - DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, - BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec -) - -try: - import gym - from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox - gym_available = True -except ImportError: - gym_available = False - -def is_discrete_space(action_space): - discrete_types = ( - GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, - DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec - ) - continuous_types = ( - GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec - ) - - if gym_available: - discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary) - continuous_types += (GymBox,) - - if isinstance(action_space, discrete_types): - return True - elif isinstance(action_space, continuous_types): - return False - else: - raise ValueError(f"Unsupported action space type: {type(action_space)}") - -def get_space_shape(action_space): - discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary, - DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec) - continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec) - - if gym_available: - discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary) - continuous_types += (GymBox,) - - if isinstance(action_space, discrete_types): - if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)): - return (action_space.n,) - elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)): - return (sum(action_space.nvec),) - elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)): - return (action_space.n,) - elif gym_available: - if isinstance(action_space, GymDiscrete): - return (action_space.n,) - elif isinstance(action_space, GymMultiDiscrete): - return (sum(action_space.nvec),) - elif isinstance(action_space, GymMultiBinary): - return (action_space.n,) - elif isinstance(action_space, continuous_types): - return action_space.shape - - raise ValueError(f"Unsupported action space type: {type(action_space)}")