Compare commits

...

4 Commits

SHA1        Message                                        Date
5a6069daf4  Showcase TRPL in README                        2024-05-31 18:25:42 +02:00
1086c9f6fd  Remove all etsts for now (interface changed)   2024-05-31 18:25:17 +02:00
015f1e256a  Refactor                                       2024-05-31 18:25:03 +02:00
0bf748869a  Switch to pyproject.toml                       2024-05-31 18:24:47 +02:00
8 changed files with 130 additions and 185 deletions

README.md

@@ -23,11 +23,11 @@ Fancy RL provides two main components:
 1. **Ready-to-use Classes for PPO / TRPL**: These classes allow you to quickly get started with reinforcement learning algorithms, enjoying the performance and hackability that comes with using TorchRL.
    ```python
-   from fancy_rl import PPO
-   ppo = PPO("CartPole-v1")
-   ppo.train()
+   from fancy_rl import PPO, TRPL
+   model = TRPL("CartPole-v1")
+   model.train()
    ```
    For environments, you can pass any [gymnasium](https://gymnasium.farama.org/) or [Fancy Gym](https://alrhub.github.io/fancy_gym/) environment ID as a string, a function returning a gymnasium environment, or an already instantiated gymnasium environment. Future plans include supporting other torchrl environments.
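The environment-spec flexibility described in that README paragraph can be illustrated with a short sketch. This is my own example rather than part of the diff, and it assumes the environment spec is the first positional argument of `PPO`, as in the README snippet above:

```python
import gymnasium as gym
from fancy_rl import PPO

# Three ways to specify the environment, per the README text above:
ppo_from_id = PPO("CartPole-v1")                    # gymnasium / Fancy Gym ID as a string
ppo_from_fn = PPO(lambda: gym.make("CartPole-v1"))  # callable returning a gymnasium env
ppo_from_env = PPO(gym.make("CartPole-v1"))         # already instantiated gymnasium env
```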

fancy_rl/__init__.py

@@ -1,6 +1,9 @@
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import MLPPolicy
-from fancy_rl.loggers import TerminalLogger, WandbLogger
-from fancy_rl.utils import make_env
-
-__all__ = ["PPO", "MLPPolicy", "TerminalLogger", "WandbLogger", "make_env"]
+import gymnasium
+try:
+    import fancy_gym
+except ImportError:
+    pass
+
+from fancy_rl.ppo import PPO
+
+__all__ = ["PPO"]

fancy_rl/on_policy.py

@@ -1,18 +1,13 @@
 import torch
-from abc import ABC, abstractmethod
-from torchrl.record.loggers import Logger
-from torch.optim import Adam
+import gymnasium as gym
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.envs.libs.gym import GymWrapper
+from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-import gymnasium as gym
-try:
-    import fancy_gym
-except ImportError:
-    pass
+from abc import ABC, abstractmethod
 
 class OnPolicy(ABC):
     def __init__(
@@ -20,6 +15,7 @@ class OnPolicy(ABC):
         policy,
         env_spec,
         loggers,
+        optimizers,
         learning_rate,
         n_steps,
         batch_size,
@@ -41,6 +37,7 @@ class OnPolicy(ABC):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
+        self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
         self.batch_size = batch_size
@@ -90,6 +87,15 @@ class OnPolicy(ABC):
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
 
+    def train_step(self, batch):
+        for optimizer in self.optimizers.values():
+            optimizer.zero_grad()
+        loss = self.loss_module(batch)
+        loss.backward()
+        for optimizer in self.optimizers.values():
+            optimizer.step()
+        return loss
+
     def train(self):
         collected_frames = 0
@@ -136,10 +142,6 @@ class OnPolicy(ABC):
             for logger in self.loggers:
                 logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
 
-    @abstractmethod
-    def train_step(self, batch):
-        pass
-
 def dump_video(module):
     if isinstance(module, VideoRecorder):
         module.dump()

fancy_rl/policy.py

@@ -1,71 +1,59 @@
-import torch
-from torch import nn
-from torch.distributions import Categorical, Normal
-import gymnasium as gym
-
-class Actor(nn.Module):
-    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        self.continuous = isinstance(action_space, gym.spaces.Box)
-        input_dim = observation_space.shape[-1]
-        if self.continuous:
-            output_dim = action_space.shape[-1]
-        else:
-            output_dim = action_space.n
-
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-
-        if self.continuous:
-            self.mu_layer = nn.Linear(last_dim, output_dim)
-            self.log_std_layer = nn.Linear(last_dim, output_dim)
-        else:
-            layers.append(nn.Linear(last_dim, output_dim))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        if self.continuous:
-            mu = self.mu_layer(x)
-            log_std = self.log_std_layer(x)
-            return mu, log_std.exp()
-        else:
-            return self.model(x)
-
-    def act(self, observation, deterministic=False):
-        with torch.no_grad():
-            if self.continuous:
-                mu, std = self.forward(observation)
-                if deterministic:
-                    action = mu
-                else:
-                    action_dist = Normal(mu, std)
-                    action = action_dist.sample()
-            else:
-                logits = self.forward(observation)
-                if deterministic:
-                    action = logits.argmax(dim=-1)
-                else:
-                    action_dist = Categorical(logits=logits)
-                    action = action_dist.sample()
-        return action
-
-class Critic(nn.Module):
-    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        input_dim = observation_space.shape[-1]
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-        layers.append(nn.Linear(last_dim, 1))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.model(x).squeeze(-1)
+import torch.nn as nn
+from tensordict.nn import TensorDictModule
+from torchrl.modules import MLP
+from tensordict.nn.distributions import NormalParamExtractor
+
+class SharedModule(TensorDictModule):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        if hidden_sizes:
+            shared_module = MLP(
+                in_features=obs_space.shape[-1],
+                out_features=hidden_sizes[-1],
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            )
+            out_features = hidden_sizes[-1]
+        else:
+            shared_module = nn.Identity()
+            out_features = obs_space.shape[-1]
+        super().__init__(
+            module=shared_module,
+            in_keys=["observation"],
+            out_keys=["shared"],
+        )
+        self.out_features = out_features
+
+class Actor(TensorDictModule):
+    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+        actor_module = nn.Sequential(
+            MLP(
+                in_features=shared_module.out_features,
+                out_features=act_space.shape[-1] * 2,
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            ),
+            NormalParamExtractor(),
+        ).to(device)
+        super().__init__(
+            module=actor_module,
+            in_keys=["shared"],
+            out_keys=["loc", "scale"],
+        )
+
+class Critic(TensorDictModule):
+    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+        critic_module = MLP(
+            in_features=shared_module.out_features,
+            out_features=1,
+            num_cells=hidden_sizes,
+            activation_class=getattr(nn, activation_fn),
+            device=device
+        ).to(device)
+        super().__init__(
+            module=critic_module,
+            in_keys=["shared"],
+            out_keys=["state_value"],
+        )
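For readers unfamiliar with `TensorDictModule`, here is a minimal usage sketch of the rewritten policy components. It is my own illustration, not part of the commit; the choice of `Pendulum-v1`, the batch size, and the CPU device are arbitrary:

```python
import torch
import gymnasium as gym
from tensordict import TensorDict
from fancy_rl.policy import SharedModule, Actor, Critic

env = gym.make("Pendulum-v1")  # continuous actions, matching the Normal-parameter Actor
obs_dim = env.observation_space.shape[-1]

shared = SharedModule(env.observation_space, [64], "Tanh", "cpu")
actor = Actor(shared, env.action_space, [64, 64], "Tanh", "cpu")
critic = Critic(shared, [64, 64], "Tanh", "cpu")

# Each module reads its in_keys from the TensorDict and writes its out_keys back.
td = TensorDict({"observation": torch.randn(4, obs_dim)}, batch_size=[4])
shared(td)   # adds td["shared"]
actor(td)    # adds td["loc"] and td["scale"]
critic(td)   # adds td["state_value"]
```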

fancy_rl/ppo.py

@@ -1,11 +1,9 @@
 import torch
-import torch.nn as nn
+from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
-from torchrl.record.loggers import get_logger
-from on_policy import OnPolicy
-from policy import Actor, Critic
-import gymnasium as gym
+from fancy_rl.on_policy import OnPolicy
+from fancy_rl.policy import Actor, Critic, SharedModule
 
 class PPO(OnPolicy):
     def __init__(
@@ -14,8 +12,9 @@ class PPO(OnPolicy):
         loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
-        actor_activation_fn="ReLU",
-        critic_activation_fn="ReLU",
+        actor_activation_fn="Tanh",
+        critic_activation_fn="Tanh",
+        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
@@ -33,21 +32,45 @@ class PPO(OnPolicy):
         env_spec_eval=None,
         eval_episodes=10,
     ):
+        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         # Initialize environment to get observation and action space sizes
-        env = self.make_env(env_spec)
+        self.env_spec = env_spec
+        env = self.make_env()
         obs_space = env.observation_space
         act_space = env.action_space
 
-        actor_activation_fn = getattr(nn, actor_activation_fn)
-        critic_activation_fn = getattr(nn, critic_activation_fn)
-
-        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
-        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
+        # Define the shared, actor, and critic modules
+        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
+        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
+
+        # Combine into an ActorValueOperator
+        self.ac_module = ActorValueOperator(
+            self.shared_module,
+            self.actor,
+            self.critic
+        )
+
+        # Define the policy as a ProbabilisticActor
+        self.policy = ProbabilisticActor(
+            module=self.ac_module.get_policy_operator(),
+            in_keys=["loc", "scale"],
+            out_keys=["action"],
+            distribution_class=torch.distributions.Normal,
+            return_log_prob=True
+        )
+
+        optimizers = {
+            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
+        }
 
         super().__init__(
-            policy=self.actor,
+            policy=self.policy,
             env_spec=env_spec,
             loggers=loggers,
+            optimizers=optimizers,
             learning_rate=learning_rate,
             n_steps=n_steps,
             batch_size=batch_size,
@@ -82,15 +105,3 @@ class PPO(OnPolicy):
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
-
-        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
-        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
-
-    def train_step(self, batch):
-        self.actor_optimizer.zero_grad()
-        self.critic_optimizer.zero_grad()
-        loss = self.loss_module(batch)
-        loss.backward()
-        self.actor_optimizer.step()
-        self.critic_optimizer.step()
-        return loss
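As a usage note, here is a sketch of constructing the refactored PPO with its new defaults. This is my own example, not part of the diff; the assumption that the environment spec is the first positional argument comes from the README snippet, and `Pendulum-v1` is chosen because the new actor emits `loc`/`scale` for a Normal distribution, i.e. it targets continuous action spaces:

```python
from fancy_rl.ppo import PPO

model = PPO(
    "Pendulum-v1",                # assumed first positional argument (env spec)
    actor_activation_fn="Tanh",   # new default in this commit
    critic_activation_fn="Tanh",  # new default in this commit
    shared_stem_sizes=[64],       # new shared-stem parameter
)
model.train()
```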

pyproject.toml (new file)

@@ -0,0 +1,13 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "fancy_rl"
+version = "0.1.0"
+dependencies = [
+    "gymnasium",
+    "pyyaml",
+    "torch",
+    "torchrl"
+]

setup.py (deleted)

@@ -1,19 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name="fancy_rl",
-    version="0.1",
-    packages=find_packages(),
-    install_requires=[
-        "torch",
-        "torchrl",
-        "gymnasium",
-        "pyyaml",
-    ],
-    entry_points={
-        "console_scripts": [
-            "fancy_rl=fancy_rl.example:main",
-        ],
-    },
-)

PPO test file

@@ -1,54 +1 @@
-import pytest
-import torch
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger
-from fancy_rl.utils import make_env
-
-@pytest.fixture
-def policy():
-    return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64])
-
-@pytest.fixture
-def loggers():
-    return [TerminalLogger()]
-
-@pytest.fixture
-def env_fn():
-    return make_env("CartPole-v1")
-
-def test_ppo_train(policy, loggers, env_fn):
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              learning_rate=3e-4,
-              n_steps=2048,
-              batch_size=64,
-              n_epochs=10,
-              gamma=0.99,
-              gae_lambda=0.95,
-              clip_range=0.2,
-              total_timesteps=10000,
-              eval_interval=2048,
-              eval_deterministic=True,
-              eval_episodes=5,
-              seed=42)
-    ppo.train()
-
-def test_ppo_evaluate(policy, loggers, env_fn):
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              learning_rate=3e-4,
-              n_steps=2048,
-              batch_size=64,
-              n_epochs=10,
-              gamma=0.99,
-              gae_lambda=0.95,
-              clip_range=0.2,
-              total_timesteps=10000,
-              eval_interval=2048,
-              eval_deterministic=True,
-              eval_episodes=5,
-              seed=42)
-    ppo.evaluate(epoch=0)
+# TODO