Use trl space definitions (not gym)

Dominik Moritz Roth 2024-11-07 11:39:45 +01:00
parent 8a078fb59e
commit 5c44448e53
2 changed files with 16 additions and 8 deletions

View File

@@ -2,9 +2,9 @@ import torch
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.utils import is_discrete_space
 
 class PPO(OnPolicy):
     def __init__(
@@ -41,17 +41,25 @@ class PPO(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
 
-        obs_space = env.observation_space
-        act_space = env.action_space
-        self.discrete = is_discrete_space(act_space)
+        # Get spaces from specs for parallel env
+        obs_space = env.observation_spec
+        act_space = env.action_spec
+
+        self.discrete = isinstance(act_space, DiscreteTensorSpec)
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         if self.discrete:
             distribution_class = torch.distributions.Categorical
-            distribution_kwargs = {"logits": "action_logits"}
+            self.prob_actor = ProbabilisticActor(
+                module=self.actor,
+                distribution_class=distribution_class,
+                return_log_prob=True,
+                in_keys=["logits"],
+                out_keys=["action"],
+            )
         else:
             if full_covariance:
                 distribution_class = torch.distributions.MultivariateNormal
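
For reference, a minimal sketch (not part of this commit) of what the new discrete branch does: detect a discrete action spec via isinstance and wrap a logits-producing module in a ProbabilisticActor with a Categorical head. The obs_dim/n_actions values and the plain Linear module are illustrative stand-ins, not the repo's Actor class.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor

# A discrete action spec with 4 actions, as env.action_spec might return
act_space = DiscreteTensorSpec(4)
discrete = isinstance(act_space, DiscreteTensorSpec)  # True

# Stand-in for the repo's Actor: any module that writes a "logits" entry
obs_dim, n_actions = 8, 4
actor_net = TensorDictModule(
    torch.nn.Linear(obs_dim, n_actions),
    in_keys=["observation"],
    out_keys=["logits"],
)
prob_actor = ProbabilisticActor(
    module=actor_net,
    distribution_class=torch.distributions.Categorical,
    return_log_prob=True,
    in_keys=["logits"],
    out_keys=["action"],
)

td = TensorDict({"observation": torch.randn(2, obs_dim)}, batch_size=[2])
td = prob_actor(td)  # samples an "action" entry and stores its log-prob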

View File

@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from typing import Dict, Any, Optional
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from torchrl.modules import ProbabilisticActor, ValueOperator
 from torchrl.objectives import ClipPPOLoss
 from torchrl.collectors import SyncDataCollector
@@ -10,7 +11,6 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
-from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
@@ -80,7 +80,7 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
-        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
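
As a sanity check, a hedged sketch (not part of this commit, and assuming gymnasium plus TorchRL's GymEnv wrapper are installed) of the same continuous-only guard applied to specs read directly from a TorchRL env, in line with the commit's switch away from gym-style space attributes:

from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")        # continuous control task
obs_space = env.observation_spec   # CompositeSpec holding the "observation" entry
act_space = env.action_spec        # BoundedTensorSpec for Pendulum

assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"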