Compare commits
4 Commits
bd507c37c3 ... 5a6069daf4

| Author | SHA1 | Date |
|---|---|---|
| | 5a6069daf4 | |
| | 1086c9f6fd | |
| | 015f1e256a | |
| | 0bf748869a | |
README.md

@@ -23,11 +23,11 @@ Fancy RL provides two main components:
 1. **Ready-to-use Classes for PPO / TRPL**: These classes allow you to quickly get started with reinforcement learning algorithms, enjoying the performance and hackability that comes with using TorchRL.
 
 ```python
-from fancy_rl import PPO
+from fancy_rl import PPO, TRPL
 
-ppo = PPO("CartPole-v1")
+model = TRPL("CartPole-v1")
 
-ppo.train()
+model.train()
 ```
 
 For environments, you can pass any [gymnasium](https://gymnasium.farama.org/) or [Fancy Gym](https://alrhub.github.io/fancy_gym/) environment ID as a string, a function returning a gymnasium environment, or an already instantiated gymnasium environment. Future plans include supporting other torchrl environments.
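The README paragraph above lists three accepted forms for the environment argument. A minimal sketch of all three, assuming the `PPO(env_spec)` constructor shown in the README ("Pendulum-v1" is only an illustrative environment ID):

```python
import gymnasium as gym
from fancy_rl import PPO

# 1) A gymnasium / Fancy Gym environment ID passed as a string
ppo_from_id = PPO("Pendulum-v1")

# 2) A function that returns a gymnasium environment
ppo_from_fn = PPO(lambda: gym.make("Pendulum-v1"))

# 3) An already instantiated gymnasium environment
ppo_from_env = PPO(gym.make("Pendulum-v1"))
```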
fancy_rl/__init__.py

@@ -1,6 +1,9 @@
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import MLPPolicy
-from fancy_rl.loggers import TerminalLogger, WandbLogger
-from fancy_rl.utils import make_env
+import gymnasium
+try:
+    import fancy_gym
+except ImportError:
+    pass
 
-__all__ = ["PPO", "MLPPolicy", "TerminalLogger", "WandbLogger", "make_env"]
+from fancy_rl.ppo import PPO
+
+__all__ = ["PPO"]
fancy_rl/on_policy.py

@@ -1,18 +1,13 @@
 import torch
-from abc import ABC, abstractmethod
-from torchrl.record.loggers import Logger
-from torch.optim import Adam
+import gymnasium as gym
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.envs.libs.gym import GymWrapper
+from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-import gymnasium as gym
-try:
-    import fancy_gym
-except ImportError:
-    pass
+from abc import ABC, abstractmethod
+
 
 class OnPolicy(ABC):
     def __init__(
@@ -20,6 +15,7 @@ class OnPolicy(ABC):
         policy,
         env_spec,
         loggers,
+        optimizers,
         learning_rate,
         n_steps,
         batch_size,
@@ -41,6 +37,7 @@ class OnPolicy(ABC):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
+        self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
         self.batch_size = batch_size
@@ -90,6 +87,15 @@ class OnPolicy(ABC):
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
 
+    def train_step(self, batch):
+        for optimizer in self.optimizers.values():
+            optimizer.zero_grad()
+        loss = self.loss_module(batch)
+        loss.backward()
+        for optimizer in self.optimizers.values():
+            optimizer.step()
+        return loss
+
     def train(self):
         collected_frames = 0
 
@@ -136,10 +142,6 @@ class OnPolicy(ABC):
         for logger in self.loggers:
             logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
 
-    @abstractmethod
-    def train_step(self, batch):
-        pass
-
 def dump_video(module):
     if isinstance(module, VideoRecorder):
         module.dump()
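The `train_step` shown above now lives in the `OnPolicy` base class and drives an arbitrary dict of optimizers. A self-contained toy run of that update loop; the modules and loss below are stand-ins for illustration, not Fancy RL code:

```python
import torch
import torch.nn as nn

# Stand-in parameter groups; in fancy_rl these would be the actor and critic networks.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
optimizers = {
    "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
}

def loss_module(batch):
    # Stand-in for the ClipPPOLoss call; ties both modules into one scalar loss.
    return actor(batch).pow(2).mean() + critic(batch).pow(2).mean()

# Same sequence as the new OnPolicy.train_step: zero all grads,
# compute the joint loss, backprop once, then step every optimizer.
batch = torch.randn(8, 4)
for optimizer in optimizers.values():
    optimizer.zero_grad()
loss = loss_module(batch)
loss.backward()
for optimizer in optimizers.values():
    optimizer.step()
print(loss.item())
```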
fancy_rl/policy.py

@@ -1,71 +1,59 @@
-import torch
-from torch import nn
-from torch.distributions import Categorical, Normal
-import gymnasium as gym
-
-class Actor(nn.Module):
-    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        self.continuous = isinstance(action_space, gym.spaces.Box)
-        input_dim = observation_space.shape[-1]
-        if self.continuous:
-            output_dim = action_space.shape[-1]
-        else:
-            output_dim = action_space.n
-
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-
-        if self.continuous:
-            self.mu_layer = nn.Linear(last_dim, output_dim)
-            self.log_std_layer = nn.Linear(last_dim, output_dim)
-        else:
-            layers.append(nn.Linear(last_dim, output_dim))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        if self.continuous:
-            mu = self.mu_layer(x)
-            log_std = self.log_std_layer(x)
-            return mu, log_std.exp()
-        else:
-            return self.model(x)
-
-    def act(self, observation, deterministic=False):
-        with torch.no_grad():
-            if self.continuous:
-                mu, std = self.forward(observation)
-                if deterministic:
-                    action = mu
-                else:
-                    action_dist = Normal(mu, std)
-                    action = action_dist.sample()
-            else:
-                logits = self.forward(observation)
-                if deterministic:
-                    action = logits.argmax(dim=-1)
-                else:
-                    action_dist = Categorical(logits=logits)
-                    action = action_dist.sample()
-        return action
-
-class Critic(nn.Module):
-    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        input_dim = observation_space.shape[-1]
-
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-        layers.append(nn.Linear(last_dim, 1))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.model(x).squeeze(-1)
+import torch.nn as nn
+from tensordict.nn import TensorDictModule
+from torchrl.modules import MLP
+from tensordict.nn.distributions import NormalParamExtractor
+
+class SharedModule(TensorDictModule):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        if hidden_sizes:
+            shared_module = MLP(
+                in_features=obs_space.shape[-1],
+                out_features=hidden_sizes[-1],
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            )
+            out_features = hidden_sizes[-1]
+        else:
+            shared_module = nn.Identity()
+            out_features = obs_space.shape[-1]
+
+        super().__init__(
+            module=shared_module,
+            in_keys=["observation"],
+            out_keys=["shared"],
+        )
+        self.out_features = out_features
+
+class Actor(TensorDictModule):
+    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+        actor_module = nn.Sequential(
+            MLP(
+                in_features=shared_module.out_features,
+                out_features=act_space.shape[-1] * 2,
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            ),
+            NormalParamExtractor(),
+        ).to(device)
+        super().__init__(
+            module=actor_module,
+            in_keys=["shared"],
+            out_keys=["loc", "scale"],
+        )
+
+class Critic(TensorDictModule):
+    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+        critic_module = MLP(
+            in_features=shared_module.out_features,
+            out_features=1,
+            num_cells=hidden_sizes,
+            activation_class=getattr(nn, activation_fn),
+            device=device
+        ).to(device)
+        super().__init__(
+            module=critic_module,
+            in_keys=["shared"],
+            out_keys=["state_value"],
+        )
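The rewritten policy modules communicate through TensorDict keys: the shared stem reads `"observation"` and writes `"shared"`, the actor turns `"shared"` into `"loc"`/`"scale"`, and the critic turns it into `"state_value"`. A minimal sketch of that flow, assuming the constructor signatures above and a continuous-control environment (Pendulum-v1 is only illustrative):

```python
import torch
import gymnasium as gym
from tensordict import TensorDict
from fancy_rl.policy import SharedModule, Actor, Critic

env = gym.make("Pendulum-v1")
device = torch.device("cpu")

shared = SharedModule(env.observation_space, [64], "Tanh", device)
actor = Actor(shared, env.action_space, [64, 64], "Tanh", device)
critic = Critic(shared, [64, 64], "Tanh", device)

# A batch of 8 fake observations; each module adds its out_keys to the TensorDict.
td = TensorDict({"observation": torch.randn(8, env.observation_space.shape[-1])}, batch_size=[8])
td = shared(td)   # adds "shared"
td = actor(td)    # adds "loc" and "scale"
td = critic(td)   # adds "state_value"
print(td["loc"].shape, td["scale"].shape, td["state_value"].shape)
```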
fancy_rl/ppo.py

@@ -1,11 +1,9 @@
 import torch
-import torch.nn as nn
+from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
-from torchrl.record.loggers import get_logger
-from on_policy import OnPolicy
-from policy import Actor, Critic
-import gymnasium as gym
+from fancy_rl.on_policy import OnPolicy
+from fancy_rl.policy import Actor, Critic, SharedModule
 
 class PPO(OnPolicy):
     def __init__(
@@ -14,8 +12,9 @@ class PPO(OnPolicy):
         loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
-        actor_activation_fn="ReLU",
-        critic_activation_fn="ReLU",
+        actor_activation_fn="Tanh",
+        critic_activation_fn="Tanh",
+        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
@@ -33,21 +32,45 @@ class PPO(OnPolicy):
         env_spec_eval=None,
         eval_episodes=10,
     ):
+        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Initialize environment to get observation and action space sizes
-        env = self.make_env(env_spec)
+        self.env_spec = env_spec
+        env = self.make_env()
         obs_space = env.observation_space
         act_space = env.action_space
 
-        actor_activation_fn = getattr(nn, actor_activation_fn)
-        critic_activation_fn = getattr(nn, critic_activation_fn)
-        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
-        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
+        # Define the shared, actor, and critic modules
+        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
+        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
+
+        # Combine into an ActorValueOperator
+        self.ac_module = ActorValueOperator(
+            self.shared_module,
+            self.actor,
+            self.critic
+        )
+
+        # Define the policy as a ProbabilisticActor
+        self.policy = ProbabilisticActor(
+            module=self.ac_module.get_policy_operator(),
+            in_keys=["loc", "scale"],
+            out_keys=["action"],
+            distribution_class=torch.distributions.Normal,
+            return_log_prob=True
+        )
+
+        optimizers = {
+            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
+        }
 
         super().__init__(
-            policy=self.actor,
+            policy=self.policy,
             env_spec=env_spec,
             loggers=loggers,
+            optimizers=optimizers,
             learning_rate=learning_rate,
             n_steps=n_steps,
             batch_size=batch_size,
@@ -82,15 +105,3 @@ class PPO(OnPolicy):
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
-
-        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
-        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
-
-    def train_step(self, batch):
-        self.actor_optimizer.zero_grad()
-        self.critic_optimizer.zero_grad()
-        loss = self.loss_module(batch)
-        loss.backward()
-        self.actor_optimizer.step()
-        self.critic_optimizer.step()
-        return loss
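A hedged usage sketch of the resulting policy objects: the `ProbabilisticActor` samples an `"action"` (and, with `return_log_prob=True`, a `"sample_log_prob"`), while the value branch of the `ActorValueOperator` yields `"state_value"`. It assumes the constructor above works with a continuous-control environment ID such as "Pendulum-v1":

```python
import torch
from tensordict import TensorDict
from fancy_rl import PPO

ppo = PPO("Pendulum-v1")

# One fake Pendulum observation (3-dimensional) in a TensorDict.
td = TensorDict({"observation": torch.randn(1, 3)}, batch_size=[1])

# Sample an action and its log-probability from the probabilistic policy.
td = ppo.policy(td)
print(td["action"], td["sample_log_prob"])

# The value branch of the ActorValueOperator adds "state_value".
td = ppo.ac_module.get_value_operator()(td)
print(td["state_value"])
```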
pyproject.toml (new file)

@@ -0,0 +1,13 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "fancy_rl"
+version = "0.1.0"
+dependencies = [
+    "gymnasium",
+    "pyyaml",
+    "torch",
+    "torchrl"
+]
setup.py (deleted)

@@ -1,19 +0,0 @@
-from setuptools import setup, find_packages
-
-
-setup(
-    name="fancy_rl",
-    version="0.1",
-    packages=find_packages(),
-    install_requires=[
-        "torch",
-        "torchrl",
-        "gymnasium",
-        "pyyaml",
-    ],
-    entry_points={
-        "console_scripts": [
-            "fancy_rl=fancy_rl.example:main",
-        ],
-    },
-)
Tests

@@ -1,54 +1 @@
-import pytest
-import torch
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import Policy
-from fancy_rl.loggers import TerminalLogger
-from fancy_rl.utils import make_env
-
-@pytest.fixture
-def policy():
-    return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64])
-
-@pytest.fixture
-def loggers():
-    return [TerminalLogger()]
-
-@pytest.fixture
-def env_fn():
-    return make_env("CartPole-v1")
-
-def test_ppo_train(policy, loggers, env_fn):
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              learning_rate=3e-4,
-              n_steps=2048,
-              batch_size=64,
-              n_epochs=10,
-              gamma=0.99,
-              gae_lambda=0.95,
-              clip_range=0.2,
-              total_timesteps=10000,
-              eval_interval=2048,
-              eval_deterministic=True,
-              eval_episodes=5,
-              seed=42)
-    ppo.train()
-
-def test_ppo_evaluate(policy, loggers, env_fn):
-    ppo = PPO(policy=policy,
-              env_fn=env_fn,
-              loggers=loggers,
-              learning_rate=3e-4,
-              n_steps=2048,
-              batch_size=64,
-              n_epochs=10,
-              gamma=0.99,
-              gae_lambda=0.95,
-              clip_range=0.2,
-              total_timesteps=10000,
-              eval_interval=2048,
-              eval_deterministic=True,
-              eval_episodes=5,
-              seed=42)
-    ppo.evaluate(epoch=0)
+# TODO