Use trl space definitions (not gym)
This commit is contained in:
parent
8a078fb59e
commit
5c44448e53
@ -2,9 +2,9 @@ import torch
|
|||||||
from torchrl.modules import ProbabilisticActor
|
from torchrl.modules import ProbabilisticActor
|
||||||
from torchrl.objectives import ClipPPOLoss
|
from torchrl.objectives import ClipPPOLoss
|
||||||
from torchrl.objectives.value.advantages import GAE
|
from torchrl.objectives.value.advantages import GAE
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from fancy_rl.algos.on_policy import OnPolicy
|
from fancy_rl.algos.on_policy import OnPolicy
|
||||||
from fancy_rl.policy import Actor, Critic
|
from fancy_rl.policy import Actor, Critic
|
||||||
from fancy_rl.utils import is_discrete_space
|
|
||||||
|
|
||||||
class PPO(OnPolicy):
|
class PPO(OnPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -41,17 +41,25 @@ class PPO(OnPolicy):
|
|||||||
# Initialize environment to get observation and action space sizes
|
# Initialize environment to get observation and action space sizes
|
||||||
self.env_spec = env_spec
|
self.env_spec = env_spec
|
||||||
env = self.make_env()
|
env = self.make_env()
|
||||||
obs_space = env.observation_space
|
|
||||||
act_space = env.action_space
|
|
||||||
|
|
||||||
self.discrete = is_discrete_space(act_space)
|
# Get spaces from specs for parallel env
|
||||||
|
obs_space = env.observation_spec
|
||||||
|
act_space = env.action_spec
|
||||||
|
|
||||||
|
self.discrete = isinstance(act_space, DiscreteTensorSpec)
|
||||||
|
|
||||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||||
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||||
|
|
||||||
if self.discrete:
|
if self.discrete:
|
||||||
distribution_class = torch.distributions.Categorical
|
distribution_class = torch.distributions.Categorical
|
||||||
distribution_kwargs = {"logits": "action_logits"}
|
self.prob_actor = ProbabilisticActor(
|
||||||
|
module=self.actor,
|
||||||
|
distribution_class=distribution_class,
|
||||||
|
return_log_prob=True,
|
||||||
|
in_keys=["logits"],
|
||||||
|
out_keys=["action"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if full_covariance:
|
if full_covariance:
|
||||||
distribution_class = torch.distributions.MultivariateNormal
|
distribution_class = torch.distributions.MultivariateNormal
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
from torchrl.data.tensor_specs import DiscreteTensorSpec
|
||||||
from torchrl.modules import ProbabilisticActor, ValueOperator
|
from torchrl.modules import ProbabilisticActor, ValueOperator
|
||||||
from torchrl.objectives import ClipPPOLoss
|
from torchrl.objectives import ClipPPOLoss
|
||||||
from torchrl.collectors import SyncDataCollector
|
from torchrl.collectors import SyncDataCollector
|
||||||
@ -10,7 +11,6 @@ from fancy_rl.algos.on_policy import OnPolicy
|
|||||||
from fancy_rl.policy import Actor, Critic
|
from fancy_rl.policy import Actor, Critic
|
||||||
from fancy_rl.projections import get_projection, BaseProjection
|
from fancy_rl.projections import get_projection, BaseProjection
|
||||||
from fancy_rl.objectives import TRPLLoss
|
from fancy_rl.objectives import TRPLLoss
|
||||||
from fancy_rl.utils import is_discrete_space
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
@ -80,7 +80,7 @@ class TRPL(OnPolicy):
|
|||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
act_space = env.action_space
|
act_space = env.action_space
|
||||||
|
|
||||||
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
|
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
|
||||||
|
|
||||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||||
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
|
||||||
|
Loading…
Reference in New Issue
Block a user