Compare commits

..

No commits in common. "e938018494a1a5b512eb2a40ae9c4d5155d81299" and "df1ba6fe53c1fd3c2b490a28df0e586a49a48df8" have entirely different histories.

7 changed files with 96 additions and 69 deletions

View File

@ -49,7 +49,7 @@ The TRPL implementation in Fancy RL includes projections based on the Kullback-L
To run the test suite:
```bash
pytest test/
pytest test/test_ppo.py
```
## Status

View File

@ -1,12 +1,9 @@
import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
from fancy_rl.loggers import TerminalLogger
@ -50,33 +47,18 @@ class Algo(ABC):
self.eval_episodes = eval_episodes
def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec
env = self._wrap_env(env_spec)
env.reset()
return env
def _wrap_env(self, env_spec):
if isinstance(env_spec, str):
env = GymEnv(env_spec, device=self.device)
elif isinstance(env_spec, gym.Env):
env = GymWrapper(env_spec, device=self.device)
elif isinstance(env_spec, GymEnv):
env = env_spec
env = gym.make(env_spec)
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
base_env = env_spec()
return self._wrap_env(base_env)
env = env_spec()
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
else:
raise ValueError(
f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
f"got {type(env_spec)}"
)
if not env.batch_size:
env = TransformedEnv(
env,
BatchSizeTransform(batch_size=torch.Size([1]))
)
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
return env
def train_step(self, batch):
@ -96,7 +78,7 @@ class Algo(ABC):
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor})
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
action_td = self.prob_actor(td)
action = action_td["action"]

View File

@ -2,9 +2,9 @@ import torch
from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.data.tensor_specs import DiscreteTensorSpec
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.utils import is_discrete_space
class PPO(OnPolicy):
def __init__(
@ -41,25 +41,17 @@ class PPO(OnPolicy):
# Initialize environment to get observation and action space sizes
self.env_spec = env_spec
env = self.make_env()
obs_space = env.observation_space
act_space = env.action_space
# Get spaces from specs for parallel env
obs_space = env.observation_spec
act_space = env.action_spec
self.discrete = isinstance(act_space, DiscreteTensorSpec)
self.discrete = is_discrete_space(act_space)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete:
distribution_class = torch.distributions.Categorical
self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=["logits"],
out_keys=["action"],
)
distribution_kwargs = {"logits": "action_logits"}
else:
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal

View File

@ -1,7 +1,6 @@
import torch
from torch import nn
from typing import Dict, Any, Optional
from torchrl.data.tensor_specs import DiscreteTensorSpec
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector
@ -11,6 +10,7 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
@ -80,7 +80,7 @@ class TRPL(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

View File

@ -1,17 +1,14 @@
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict
class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
self.discrete = isinstance(act_space, DiscreteTensorSpec)
obs_space = obs_space["observation"]
act_space_shape = act_space.shape[1:]
obs_space_shape = obs_space.shape[1:]
self.discrete = is_discrete_space(act_space)
act_space_shape = get_space_shape(act_space)
if self.discrete and full_covariance:
raise ValueError("Full covariance is not applicable for discrete action spaces.")
@ -19,18 +16,18 @@ class Actor(TensorDictModule):
self.full_covariance = full_covariance
if self.discrete:
out_features = act_space_shape[0]
out_keys = ["logits"]
out_features = act_space_shape[-1]
out_keys = ["action_logits"]
else:
if full_covariance:
out_features = act_space_shape[0] + (act_space_shape[0] * (act_space_shape[0] + 1)) // 2
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
out_keys = ["loc", "scale_tril"]
else:
out_features = act_space_shape[0] * 2
out_features = act_space_shape[-1] * 2
out_keys = ["loc", "scale"]
actor_module = MLP(
in_features=obs_space_shape[0],
in_features=get_space_shape(obs_space)[-1],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
@ -39,7 +36,7 @@ class Actor(TensorDictModule):
if not self.discrete:
if full_covariance:
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[0])
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
else:
param_extractor = NormalParamExtractor()
actor_module = nn.Sequential(actor_module, param_extractor)
@ -66,7 +63,7 @@ class FullCovarianceNormalParamExtractor(nn.Module):
class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP(
in_features=obs_space.shape[-1],
in_features=get_space_shape(obs_space)[-1],
out_features=1,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),

View File

@ -0,0 +1,61 @@
import gymnasium
from gymnasium.spaces import Discrete as GymnasiumDiscrete, MultiDiscrete as GymnasiumMultiDiscrete, MultiBinary as GymnasiumMultiBinary, Box as GymnasiumBox
from torchrl.data.tensor_specs import (
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec,
BinaryDiscreteTensorSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
try:
import gym
from gym.spaces import Discrete as GymDiscrete, MultiDiscrete as GymMultiDiscrete, MultiBinary as GymMultiBinary, Box as GymBox
gym_available = True
except ImportError:
gym_available = False
def is_discrete_space(action_space):
discrete_types = (
GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec
)
continuous_types = (
GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec
)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
return True
elif isinstance(action_space, continuous_types):
return False
else:
raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space):
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
return (action_space.n,)
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
return (sum(action_space.nvec),)
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif gym_available:
if isinstance(action_space, GymDiscrete):
return (action_space.n,)
elif isinstance(action_space, GymMultiDiscrete):
return (sum(action_space.nvec),)
elif isinstance(action_space, GymMultiBinary):
return (action_space.n,)
elif isinstance(action_space, continuous_types):
return action_space.shape
raise ValueError(f"Unsupported action space type: {type(action_space)}")

View File

@ -2,10 +2,9 @@ import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
from torchrl.envs import GymEnv
def simple_env():
return GymEnv('LunarLander-v2', continuous=True)
return gym.make('LunarLander-v2', continuous=True)
def test_ppo_instantiation():
ppo = PPO(simple_env)
@ -15,10 +14,6 @@ def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_make():
ppo = PPO(gym.make('CartPole-v1'))
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@ -53,12 +48,12 @@ def test_ppo_predict():
def test_ppo_learn():
ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = ppo.make_env()
obs = env.reset()
obs, _ = env.reset()
for _ in range(64):
action, _next_state = ppo.predict(obs)
action, _ = ppo.predict(obs)
obs, reward, done, truncated, _ = env.step(action)
if done or truncated:
obs = env.reset()
obs, _ = env.reset()
def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000)
@ -73,10 +68,10 @@ def test_ppo_training():
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs = env.reset()
obs, _ = env.reset()
done = False
while not done:
action, _next_state = policy.predict(obs)
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated