Use torchrl space definitions (not gym)
parent 8a078fb59e
commit 5c44448e53
@@ -2,9 +2,9 @@ import torch
 from torchrl.modules import ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.utils import is_discrete_space

 class PPO(OnPolicy):
     def __init__(
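
For context: a torchrl environment exposes TensorSpec objects through env.observation_spec and env.action_spec instead of the gym-style env.observation_space / env.action_space attributes, so discreteness becomes a plain isinstance check against DiscreteTensorSpec rather than a helper like is_discrete_space. A minimal sketch of the spec API the commit switches to (illustrative values; assumes the same torchrl import path used above):

    from torchrl.data.tensor_specs import DiscreteTensorSpec

    # Roughly the torchrl counterpart of gym.spaces.Discrete(4)
    act_spec = DiscreteTensorSpec(4)

    isinstance(act_spec, DiscreteTensorSpec)  # True -> take the discrete branch
    act_spec.rand()                           # sample a random valid action tensor
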
@@ -41,17 +41,25 @@ class PPO(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space
-
-        self.discrete = is_discrete_space(act_space)
+        # Get spaces from specs for parallel env
+        obs_space = env.observation_spec
+        act_space = env.action_spec
+
+        self.discrete = isinstance(act_space, DiscreteTensorSpec)

         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         if self.discrete:
             distribution_class = torch.distributions.Categorical
             distribution_kwargs = {"logits": "action_logits"}
+            self.prob_actor = ProbabilisticActor(
+                module=self.actor,
+                distribution_class=distribution_class,
+                return_log_prob=True,
+                in_keys=["logits"],
+                out_keys=["action"],
+            )
         else:
             if full_covariance:
                 distribution_class = torch.distributions.MultivariateNormal
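
To show what the new discrete branch constructs, here is a self-contained sketch of a Categorical ProbabilisticActor (hypothetical sizes; a bare nn.Linear stands in for the commit's Actor module):

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.modules import ProbabilisticActor

    n_obs, n_act = 8, 4

    # Network that writes one logit per discrete action under the "logits" key
    policy_net = TensorDictModule(
        nn.Linear(n_obs, n_act),
        in_keys=["observation"],
        out_keys=["logits"],
    )

    prob_actor = ProbabilisticActor(
        module=policy_net,
        distribution_class=torch.distributions.Categorical,
        return_log_prob=True,      # also writes "sample_log_prob"
        in_keys=["logits"],        # read as the Categorical's logits parameter
        out_keys=["action"],
    )

    td = TensorDict({"observation": torch.randn(n_obs)}, batch_size=[])
    td = prob_actor(td)            # td now holds "action" and "sample_log_prob"
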
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from typing import Dict, Any, Optional
+from torchrl.data.tensor_specs import DiscreteTensorSpec
 from torchrl.modules import ProbabilisticActor, ValueOperator
 from torchrl.objectives import ClipPPOLoss
 from torchrl.collectors import SyncDataCollector
@@ -10,7 +11,6 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
-from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
@@ -80,7 +80,7 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space

-        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"

         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
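
The effect of the swapped assertion, sketched with illustrative specs (BoundedTensorSpec stands in for a continuous action spec; same import path as above):

    from torchrl.data.tensor_specs import BoundedTensorSpec, DiscreteTensorSpec

    box_spec = BoundedTensorSpec(-1.0, 1.0, shape=(3,))  # continuous actions
    cat_spec = DiscreteTensorSpec(4)                     # discrete actions

    assert not isinstance(box_spec, DiscreteTensorSpec)  # passes: TRPL proceeds
    # isinstance(cat_spec, DiscreteTensorSpec) is True, so a discrete action
    # spec would trip the assertion with the error message in the diff above.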