diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index d462ff8..0529241 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -2,9 +2,9 @@ import torch
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.utils import is_discrete_space
 
 class PPO(OnPolicy):
     def __init__(
@@ -41,17 +41,25 @@ class PPO(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space
-
-        self.discrete = is_discrete_space(act_space)
+
+        # Get spaces from specs for parallel env
+        obs_space = env.observation_spec
+        act_space = env.action_spec
+
+        self.discrete = isinstance(act_space, DiscreteTensorSpec)
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         if self.discrete:
             distribution_class = torch.distributions.Categorical
-            distribution_kwargs = {"logits": "action_logits"}
+            self.prob_actor = ProbabilisticActor(
+                module=self.actor,
+                distribution_class=distribution_class,
+                return_log_prob=True,
+                in_keys=["logits"],
+                out_keys=["action"],
+            )
         else:
             if full_covariance:
                 distribution_class = torch.distributions.MultivariateNormal
diff --git a/fancy_rl/algos/trpl.py b/fancy_rl/algos/trpl.py
index aa94a8b..c483fff 100644
--- a/fancy_rl/algos/trpl.py
+++ b/fancy_rl/algos/trpl.py
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from typing import Dict, Any, Optional
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from torchrl.modules import ProbabilisticActor, ValueOperator
 from torchrl.objectives import ClipPPOLoss
 from torchrl.collectors import SyncDataCollector
@@ -10,7 +11,6 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
-from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
@@ -80,7 +80,7 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
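
Sanity-check sketch (not part of the patch): it exercises the new spec-based discrete check and the Categorical ProbabilisticActor wiring introduced above. The GymEnv("CartPole-v1", categorical_action_encoding=True) environment and the toy linear logits module are illustrative stand-ins for fancy_rl's own make_env()/Actor, assuming a TorchRL version that still exposes DiscreteTensorSpec under torchrl.data.tensor_specs.

import torch
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.envs import GymEnv
from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor

# Illustrative env; categorical_action_encoding=True yields a DiscreteTensorSpec
# action spec rather than the default one-hot spec.
env = GymEnv("CartPole-v1", categorical_action_encoding=True)
act_spec = env.action_spec
assert isinstance(act_spec, DiscreteTensorSpec)  # mirrors the new self.discrete check

# Toy logits network standing in for fancy_rl's Actor (writes the "logits" key).
policy_net = TensorDictModule(
    nn.Linear(env.observation_spec["observation"].shape[-1], act_spec.space.n),
    in_keys=["observation"],
    out_keys=["logits"],
)

# Same wiring as the diff: Categorical over "logits", sampled into "action".
prob_actor = ProbabilisticActor(
    module=policy_net,
    distribution_class=torch.distributions.Categorical,
    return_log_prob=True,
    in_keys=["logits"],
    out_keys=["action"],
)

td = prob_actor(env.reset())
print(td["action"], td["sample_log_prob"])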