Simplify operations on spaces (is_discrete, shape)
This commit is contained in:
parent
5c44448e53
commit
4f8fc500b7
@ -1,14 +1,17 @@
|
||||
import torch.nn as nn
|
||||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.modules import MLP
|
||||
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||
from tensordict.nn.distributions import NormalParamExtractor
|
||||
from fancy_rl.utils import is_discrete_space, get_space_shape
|
||||
from tensordict import TensorDict
|
||||
|
||||
class Actor(TensorDictModule):
|
||||
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
|
||||
self.discrete = is_discrete_space(act_space)
|
||||
act_space_shape = get_space_shape(act_space)
|
||||
self.discrete = isinstance(act_space, DiscreteTensorSpec)
|
||||
|
||||
obs_space = obs_space["observation"]
|
||||
act_space_shape = act_space.shape[1:]
|
||||
obs_space_shape = obs_space.shape[1:]
|
||||
|
||||
if self.discrete and full_covariance:
|
||||
raise ValueError("Full covariance is not applicable for discrete action spaces.")
|
||||
@ -16,18 +19,18 @@ class Actor(TensorDictModule):
|
||||
self.full_covariance = full_covariance
|
||||
|
||||
if self.discrete:
|
||||
out_features = act_space_shape[-1]
|
||||
out_keys = ["action_logits"]
|
||||
out_features = act_space_shape[0]
|
||||
out_keys = ["logits"]
|
||||
else:
|
||||
if full_covariance:
|
||||
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
|
||||
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
|
||||
out_keys = ["loc", "scale_tril"]
|
||||
else:
|
||||
out_features = act_space_shape[-1] * 2
|
||||
out_features = act_space_shape[0] * 2
|
||||
out_keys = ["loc", "scale"]
|
||||
|
||||
actor_module = MLP(
|
||||
in_features=get_space_shape(obs_space)[-1],
|
||||
in_features=obs_space_shape[0],
|
||||
out_features=out_features,
|
||||
num_cells=hidden_sizes,
|
||||
activation_class=getattr(nn, activation_fn),
|
||||
@ -36,7 +39,7 @@ class Actor(TensorDictModule):
|
||||
|
||||
if not self.discrete:
|
||||
if full_covariance:
|
||||
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
|
||||
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
|
||||
else:
|
||||
param_extractor = NormalParamExtractor()
|
||||
actor_module = nn.Sequential(actor_module, param_extractor)
|
||||
@ -63,7 +66,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
|
||||
class Critic(TensorDictModule):
|
||||
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
|
||||
critic_module = MLP(
|
||||
in_features=get_space_shape(obs_space)[-1],
|
||||
in_features=obs_space.shape[-1],
|
||||
out_features=1,
|
||||
num_cells=hidden_sizes,
|
||||
activation_class=getattr(nn, activation_fn),
|
||||
|
@ -1,61 +0,0 @@
|
||||
import gymnasium
|
||||
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
|
||||
from torchrl.data.tensor_specs import (
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
|
||||
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
try:
|
||||
import gym
|
||||
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
|
||||
gym_available = True
|
||||
except ImportError:
|
||||
gym_available = False
|
||||
|
||||
def is_discrete_space(action_space):
|
||||
discrete_types = (
|
||||
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
|
||||
)
|
||||
continuous_types = (
|
||||
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
|
||||
)
|
||||
|
||||
if gym_available:
|
||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||
continuous_types += (GymBox,)
|
||||
|
||||
if isinstance(action_space, discrete_types):
|
||||
return True
|
||||
elif isinstance(action_space, continuous_types):
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
||||
|
||||
def get_space_shape(action_space):
|
||||
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
|
||||
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
|
||||
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
|
||||
|
||||
if gym_available:
|
||||
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
|
||||
continuous_types += (GymBox,)
|
||||
|
||||
if isinstance(action_space, discrete_types):
|
||||
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
|
||||
return (action_space.n,)
|
||||
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
|
||||
return (sum(action_space.nvec),)
|
||||
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
|
||||
return (action_space.n,)
|
||||
elif gym_available:
|
||||
if isinstance(action_space, GymDiscrete):
|
||||
return (action_space.n,)
|
||||
elif isinstance(action_space, GymMultiDiscrete):
|
||||
return (sum(action_space.nvec),)
|
||||
elif isinstance(action_space, GymMultiBinary):
|
||||
return (action_space.n,)
|
||||
elif isinstance(action_space, continuous_types):
|
||||
return action_space.shape
|
||||
|
||||
raise ValueError(f"Unsupported action space type: {type(action_space)}")
|
Loading…
Reference in New Issue
Block a user