Compare commits

..

No commits in common. "e938018494a1a5b512eb2a40ae9c4d5155d81299" and "df1ba6fe53c1fd3c2b490a28df0e586a49a48df8" have entirely different histories.

7 changed files with 96 additions and 69 deletions

View File

@ -49,7 +49,7 @@ The TRPL implementation in Fancy RL includes projections based on the Kullback-L
To run the test suite: To run the test suite:
```bash ```bash
pytest test/ pytest test/test_ppo.py
``` ```
## Status ## Status

View File

@ -1,12 +1,9 @@
import torch import torch
import gymnasium as gym import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder from torchrl.record import VideoRecorder
from abc import ABC from abc import ABC
import pdb
from tensordict import TensorDict from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
from fancy_rl.loggers import TerminalLogger from fancy_rl.loggers import TerminalLogger
@ -50,33 +47,18 @@ class Algo(ABC):
self.eval_episodes = eval_episodes self.eval_episodes = eval_episodes
def make_env(self, eval=False): def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec env_spec = self.env_spec_eval if eval else self.env_spec
env = self._wrap_env(env_spec)
env.reset()
return env
def _wrap_env(self, env_spec):
if isinstance(env_spec, str): if isinstance(env_spec, str):
env = GymEnv(env_spec, device=self.device) env = gym.make(env_spec)
elif isinstance(env_spec, gym.Env): env = GymWrapper(env).to(self.device)
env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv):
env = env_spec
elif callable(env_spec): elif callable(env_spec):
base_env = env_spec() env = env_spec()
return self._wrap_env(base_env) if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
else: else:
raise ValueError( raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}"
)
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
return env return env
def train_step(self, batch): def train_step(self, batch):
@ -96,7 +78,7 @@ class Algo(ABC):
): ):
with torch.no_grad(): with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0) obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}) td = TensorDict({"observation": obs_tensor}, batch_size=[1])
action_td = self.prob_actor(td) action_td = self.prob_actor(td)
action = action_td["action"] action = action_td["action"]

View File

@ -2,9 +2,9 @@ import torch
from torchrl.modules import ProbabilisticActor from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE from torchrl.objectives.value.advantages import GAE
from torchrl.data.tensor_specs import DiscreteTensorSpec
from fancy_rl.algos.on_policy import OnPolicy from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic from fancy_rl.policy import Actor, Critic
from fancy_rl.utils import is_discrete_space
class PPO(OnPolicy): class PPO(OnPolicy):
def __init__( def __init__(
@ -41,25 +41,17 @@ class PPO(OnPolicy):
# Initialize environment to get observation and action space sizes # Initialize environment to get observation and action space sizes
self.env_spec = env_spec self.env_spec = env_spec
env = self.make_env() env = self.make_env()
obs_space = env.observation_space
# Get spaces from specs for parallel env act_space = env.action_space
obs_space = env.observation_spec
act_space = env.action_spec self.discrete = is_discrete_space(act_space)
self.discrete = isinstance(act_space, DiscreteTensorSpec)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device) self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance) self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete: if self.discrete:
distribution_class = torch.distributions.Categorical distribution_class = torch.distributions.Categorical
self.prob_actor = ProbabilisticActor( distribution_kwargs = {"logits": "action_logits"}
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=["logits"],
out_keys=["action"],
)
else: else:
if full_covariance: if full_covariance:
distribution_class = torch.distributions.MultivariateNormal distribution_class = torch.distributions.MultivariateNormal

View File

@ -1,7 +1,6 @@
import torch import torch
from torch import nn from torch import nn
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector from torchrl.collectors import SyncDataCollector
@ -11,6 +10,7 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy from copy import deepcopy
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from tensordict import TensorDict from tensordict import TensorDict
@ -80,7 +80,7 @@ class TRPL(OnPolicy):
obs_space = env.observation_space obs_space = env.observation_space
act_space = env.action_space act_space = env.action_space
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces" assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device) self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance) self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

View File

@ -1,17 +1,14 @@
import torch.nn as nn import torch.nn as nn
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from torchrl.modules import MLP from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict from tensordict import TensorDict
class Actor(TensorDictModule): class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False): def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
self.discrete = isinstance(act_space, DiscreteTensorSpec) self.discrete = is_discrete_space(act_space)
act_space_shape = get_space_shape(act_space)
obs_space = obs_space["observation"]
act_space_shape = act_space.shape[1:]
obs_space_shape = obs_space.shape[1:]
if self.discrete and full_covariance: if self.discrete and full_covariance:
raise ValueError("Full covariance is not applicable for discrete action spaces.") raise ValueError("Full covariance is not applicable for discrete action spaces.")
@ -19,18 +16,18 @@ class Actor(TensorDictModule):
self.full_covariance = full_covariance self.full_covariance = full_covariance
if self.discrete: if self.discrete:
out_features = act_space_shape[0] out_features = act_space_shape[-1]
out_keys = ["logits"] out_keys = ["action_logits"]
else: else:
if full_covariance: if full_covariance:
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2 out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
out_keys = ["loc", "scale_tril"] out_keys = ["loc", "scale_tril"]
else: else:
out_features = act_space_shape[0] * 2 out_features = act_space_shape[-1] * 2
out_keys = ["loc", "scale"] out_keys = ["loc", "scale"]
actor_module = MLP( actor_module = MLP(
in_features=obs_space_shape[0], in_features=get_space_shape(obs_space)[-1],
out_features=out_features, out_features=out_features,
num_cells=hidden_sizes, num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn), activation_class=getattr(nn, activation_fn),
@ -39,7 +36,7 @@ class Actor(TensorDictModule):
if not self.discrete: if not self.discrete:
if full_covariance: if full_covariance:
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0]) param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
else: else:
param_extractor = NormalParamExtractor() param_extractor = NormalParamExtractor()
actor_module = nn.Sequential(actor_module, param_extractor) actor_module = nn.Sequential(actor_module, param_extractor)
@ -66,7 +63,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule): class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device): def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP( critic_module = MLP(
in_features=obs_space.shape[-1], in_features=get_space_shape(obs_space)[-1],
out_features=1, out_features=1,
num_cells=hidden_sizes, num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn), activation_class=getattr(nn, activation_fn),

View File

@ -0,0 +1,61 @@
import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import (
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space):
    """Tell whether ``action_space`` is a discrete space/spec.

    Recognizes Gymnasium spaces, torchrl tensor specs and, when the legacy
    ``gym`` package could be imported, the matching ``gym`` spaces.

    Returns:
        ``True`` for a discrete-type space, ``False`` for a continuous one.

    Raises:
        ValueError: if ``action_space`` is none of the recognized types.
    """
    discrete_kinds = [
        GymnasiumDiscrete,
        GymnasiumMultiDiscrete,
        GymnasiumMultiBinary,
        DiscreteTensorSpec,
        OneHotDiscreteTensorSpec,
        MultiDiscreteTensorSpec,
        BinaryDiscreteTensorSpec,
    ]
    continuous_kinds = [
        GymnasiumBox,
        BoundedTensorSpec,
        UnboundedContinuousTensorSpec,
    ]

    # The legacy gym names only exist when the optional import succeeded.
    if gym_available:
        discrete_kinds.extend([GymDiscrete, GymMultiDiscrete, GymMultiBinary])
        continuous_kinds.append(GymBox)

    # Discrete types are tested first, as in the original chain.
    if isinstance(action_space, tuple(discrete_kinds)):
        return True
    if isinstance(action_space, tuple(continuous_kinds)):
        return False
    raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space):
    """Return the shape tuple used to size network layers for a space/spec.

    Despite the parameter name, the argument may be any supported space;
    it is matched purely by type.

    * multi-discrete types  -> ``(sum(action_space.nvec),)``
    * multi-binary types    -> ``(action_space.n,)``
    * plain discrete types  -> ``(action_space.n,)``
    * continuous types      -> ``action_space.shape``

    Gymnasium spaces and torchrl tensor specs are always recognized; the
    legacy ``gym`` spaces are recognized only when ``gym`` could be imported.

    Raises:
        ValueError: if ``action_space`` is none of the recognized types.
    """
    multi_discrete_types = (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)
    multi_binary_types = (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)
    plain_discrete_types = (
        GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec
    )
    continuous_types = (
        GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
    )

    # The legacy gym names only exist when the optional import succeeded.
    if gym_available:
        multi_discrete_types += (GymMultiDiscrete,)
        multi_binary_types += (GymMultiBinary,)
        plain_discrete_types += (GymDiscrete,)
        continuous_types += (GymBox,)

    # Test the more specific types BEFORE the plain discrete ones.
    # NOTE(review): in torchrl, MultiDiscreteTensorSpec and
    # BinaryDiscreteTensorSpec appear to subclass DiscreteTensorSpec, so
    # checking DiscreteTensorSpec first would shadow these branches --
    # confirm against the installed torchrl version.
    if isinstance(action_space, multi_discrete_types):
        return (sum(action_space.nvec),)
    if isinstance(action_space, multi_binary_types):
        return (action_space.n,)
    if isinstance(action_space, plain_discrete_types):
        return (action_space.n,)
    if isinstance(action_space, continuous_types):
        return action_space.shape
    raise ValueError(f"Unsupported action space type: {type(action_space)}")

View File

@ -2,10 +2,9 @@ import pytest
import numpy as np import numpy as np
from fancy_rl import PPO from fancy_rl import PPO
import gymnasium as gym import gymnasium as gym
from torchrl.envs import GymEnv
def simple_env(): def simple_env():
return GymEnv('LunarLander-v2', continuous=True) return gym.make('LunarLander-v2', continuous=True)
def test_ppo_instantiation(): def test_ppo_instantiation():
ppo = PPO(simple_env) ppo = PPO(simple_env)
@ -15,10 +14,6 @@ def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1') ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO) assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_make():
ppo = PPO(gym.make('CartPole-v1'))
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3]) @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048]) @pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128]) @pytest.mark.parametrize("batch_size", [32, 64, 128])
@ -53,12 +48,12 @@ def test_ppo_predict():
def test_ppo_learn(): def test_ppo_learn():
ppo = PPO(simple_env, n_steps=64, batch_size=32) ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = ppo.make_env() env = ppo.make_env()
obs = env.reset() obs, _ = env.reset()
for _ in range(64): for _ in range(64):
action, _next_state = ppo.predict(obs) action, _ = ppo.predict(obs)
obs, reward, done, truncated, _ = env.step(action) obs, reward, done, truncated, _ = env.step(action)
if done or truncated: if done or truncated:
obs = env.reset() obs, _ = env.reset()
def test_ppo_training(): def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000) ppo = PPO(simple_env, total_timesteps=10000)
@ -73,10 +68,10 @@ def test_ppo_training():
def evaluate_policy(policy, env, n_eval_episodes=10): def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0 total_reward = 0
for _ in range(n_eval_episodes): for _ in range(n_eval_episodes):
obs = env.reset() obs, _ = env.reset()
done = False done = False
while not done: while not done:
action, _next_state = policy.predict(obs) action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action) obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward total_reward += reward
done = terminated or truncated done = terminated or truncated