Oh, I could start using git...

Dominik Moritz Roth committed on 2024-05-29 21:21:43 +02:00
commit 8946362336
12 changed files with 494 additions and 0 deletions

.gitignore (vendored, new file)
@@ -0,0 +1,4 @@
__pycache__
.venv
wandb
*.egg-info/

README.md (new file)
@@ -0,0 +1,53 @@
# Fancy RL

Fancy RL is a minimal and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Policy Layers (TRPL) built on primitives from [torchrl](https://pypi.org/project/torchrl/). Future plans include Soft Actor-Critic (SAC). The library focuses on clean, understandable code while leveraging torchrl's functionality.

We provide optional integration with [wandb](https://wandb.ai) for experiment logging.
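For example, metrics can be pushed to wandb by passing a `WandbLogger` to the trainer. A minimal sketch (the project and entity names are placeholders for your own wandb account):

```python
import gymnasium as gym

from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import WandbLogger

def env_fn():
    return gym.make("CartPole-v1")

env = env_fn()
policy = Policy(input_dim=env.observation_space.shape[0], output_dim=env.action_space.n)

# "my_project" and "my_entity" are placeholders for your own wandb project/entity
logger = WandbLogger(project="my_project", entity="my_entity", config={}, push_interval=10)

ppo = PPO(policy=policy, env_fn=env_fn, loggers=[logger])
ppo.train()
```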
## Installation

Fancy RL requires Python 3.7-3.11 (TorchRL does not yet support Python 3.12). Install in editable mode from the repository root:

```bash
pip install -e .
```
## Usage

Here's a basic example of how to train a PPO agent with Fancy RL:

```python
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
import gymnasium as gym

def env_fn():
    return gym.make("CartPole-v1")

# Create the policy (CartPole has a 4-dimensional observation and 2 discrete actions)
env = env_fn()
policy = Policy(
    input_dim=env.observation_space.shape[0],
    output_dim=env.action_space.n,
)

# Create a PPO instance with the default config
ppo = PPO(policy=policy, env_fn=env_fn)

# Train the agent
ppo.train()
```
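All of PPO's hyperparameters can also be overridden directly as keyword arguments. A minimal sketch, reusing the `policy` and `env_fn` from the example above together with the defaults defined in `fancy_rl/ppo.py`:

```python
ppo = PPO(
    policy=policy,
    env_fn=env_fn,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.2,
    total_timesteps=100_000,  # shorter run than the default, for illustration
)
ppo.train()
```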
For a more complete, config-driven example (hyperparameters loaded from YAML, multiple loggers), see `example/example.py`.
### Testing

To run the test suite:

```bash
pytest test/test_ppo.py
```
## Contributing

Contributions are welcome! Feel free to open issues or submit pull requests to enhance the library.

## License

This project is licensed under the MIT License.

example/config.yaml (new file)
@@ -0,0 +1,25 @@
policy:
  input_dim: 4
  output_dim: 2
  hidden_sizes: [64, 64]

ppo:
  learning_rate: 3.0e-4
  n_steps: 2048
  batch_size: 64
  n_epochs: 10
  gamma: 0.99
  gae_lambda: 0.95
  clip_epsilon: 0.2
  total_timesteps: 1000000
  eval_interval: 2048
  eval_deterministic: true
  eval_episodes: 10
  seed: 42

loggers:
  - type: terminal
  - type: wandb
    project: "PPO_project"
    entity: "your_entity"
    push_interval: 10

example/example.py (new file)
@@ -0,0 +1,37 @@
import yaml
import torch
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger, WandbLogger
import gymnasium as gym


def main(config_file):
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)

    env_fn = lambda: gym.make("CartPole-v1")

    # Build the policy from the dimensions given in the config
    policy_config = config['policy']
    policy = Policy(
        input_dim=policy_config['input_dim'],
        output_dim=policy_config['output_dim'],
        hidden_sizes=policy_config['hidden_sizes'],
    )

    ppo_config = config['ppo']

    # Instantiate one logger per entry in the config
    loggers_config = config['loggers']
    loggers = []
    for logger_config in loggers_config:
        logger_type = logger_config.pop('type')
        if logger_type == 'terminal':
            loggers.append(TerminalLogger(**logger_config))
        elif logger_type == 'wandb':
            loggers.append(WandbLogger(**logger_config))

    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              **ppo_config)

    ppo.train()


if __name__ == "__main__":
    main("example/config.yaml")

fancy_rl/__init__.py (new file)
@@ -0,0 +1,6 @@
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger, WandbLogger
from fancy_rl.utils import make_env

__all__ = ["PPO", "Policy", "TerminalLogger", "WandbLogger", "make_env"]

fancy_rl/loggers.py (new file)
@@ -0,0 +1,36 @@
class Logger:
    def __init__(self, push_interval=1):
        self.data = {}
        self.push_interval = push_interval

    def log(self, key, value, epoch):
        if key not in self.data:
            self.data[key] = []
        self.data[key].append((epoch, value))

    def end_of_epoch(self, epoch):
        if epoch % self.push_interval == 0:
            self.push()

    def push(self):
        raise NotImplementedError("Push method should be implemented by subclasses")


class TerminalLogger(Logger):
    def push(self):
        for key, values in self.data.items():
            for epoch, value in values:
                print(f"Epoch {epoch}: {key} = {value}")
        self.data = {}


class WandbLogger(Logger):
    def __init__(self, project, entity, config=None, push_interval=1):
        super().__init__(push_interval)
        import wandb  # imported lazily so wandb stays an optional dependency

        self.wandb = wandb
        self.wandb.init(project=project, entity=entity, config=config)

    def push(self):
        for key, values in self.data.items():
            for epoch, value in values:
                self.wandb.log({key: value, 'epoch': epoch})
        self.data = {}
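Both built-in loggers buffer values through `Logger.log(key, value, epoch)` and flush them in `push()`, so a new logging backend only has to override `push()`. A minimal sketch of such a subclass (the `CSVLogger` name and CSV output are illustrative, not part of this commit):

```python
import csv

from fancy_rl.loggers import Logger

class CSVLogger(Logger):
    """Illustrative subclass: flushes buffered (epoch, value) pairs to a CSV file."""

    def __init__(self, path="metrics.csv", push_interval=1):
        super().__init__(push_interval)
        self.path = path

    def push(self):
        with open(self.path, "a", newline="") as f:
            writer = csv.writer(f)
            for key, values in self.data.items():
                for epoch, value in values:
                    writer.writerow([epoch, key, value])
        self.data = {}
```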

fancy_rl/on_policy.py (new file)
@@ -0,0 +1,131 @@
import torch
from abc import ABC, abstractmethod
from fancy_rl.loggers import Logger
from torch.optim import Adam


class OnPolicy(ABC):
    def __init__(
        self,
        policy,
        env_fn,
        loggers,
        learning_rate,
        n_steps,
        batch_size,
        n_epochs,
        gamma,
        gae_lambda,
        total_timesteps,
        eval_interval,
        eval_deterministic,
        entropy_coef,
        critic_coef,
        normalize_advantage,
        device=None,
        **kwargs
    ):
        self.policy = policy
        self.env_fn = env_fn
        self.loggers = loggers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.kwargs = kwargs
        self.clip_range = 0.2

    def train(self):
        self.env = self.env_fn()
        # gymnasium's reset() returns (observation, info)
        state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
        for t in range(int(self.total_timesteps)):
            rollout = self.collect_rollouts(state)
            for batch in self.get_batches(rollout):
                loss = self.train_step(batch)
                for logger in self.loggers:
                    logger.log("loss", loss.item(), epoch=t)
            # Flush logger buffers according to their push_interval
            for logger in self.loggers:
                logger.end_of_epoch(t)
            if (t + 1) % self.eval_interval == 0:
                self.evaluate(t)

    def evaluate(self, epoch):
        eval_env = self.env_fn()
        returns = []
        for _ in range(self.kwargs.get("eval_episodes", 10)):
            state, _ = eval_env.reset(seed=self.kwargs.get("seed", None))
            done = False
            total_return = 0
            while not done:
                with torch.no_grad():
                    action = self.policy.act(state, deterministic=self.eval_deterministic)
                # gymnasium's step() returns (obs, reward, terminated, truncated, info)
                state, reward, terminated, truncated, _ = eval_env.step(action.item())
                done = terminated or truncated
                total_return += reward
            returns.append(total_return)
        avg_return = sum(returns) / len(returns)
        for logger in self.loggers:
            logger.log("eval_avg_return", avg_return, epoch=epoch)

    def collect_rollouts(self, state):
        # Step the environment for n_steps and record transitions
        rollouts = []
        for _ in range(self.n_steps):
            action = self.policy.act(state).item()
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            rollouts.append((state, action, reward, next_state, done))
            state = next_state
            if done:
                state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
        return rollouts

    def get_batches(self, rollouts):
        data = self.prepare_data(rollouts)
        n_samples = len(data["obs"])
        n_batches = n_samples // self.batch_size
        batches = []
        for _ in range(n_batches):
            batch_indices = torch.randint(0, n_samples, (self.batch_size,))
            batches.append({key: value[batch_indices] for key, value in data.items()})
        return batches

    def prepare_data(self, rollouts):
        obs, actions, rewards, next_obs, dones = zip(*rollouts)
        obs = torch.tensor(obs, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_obs = torch.tensor(next_obs, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)
        data = {
            "obs": obs,
            "actions": actions,
            "rewards": rewards,
            "next_obs": next_obs,
            "dones": dones,
        }
        # Advantage estimation is delegated to the adv_module set up by the subclass
        data = self.adv_module(data)
        return data

    @abstractmethod
    def train_step(self, batch):
        pass

fancy_rl/policy.py (new file)
@@ -0,0 +1,27 @@
import torch
from torch import nn


class Policy(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_sizes=[64, 64]):
        super().__init__()
        layers = []
        last_dim = input_dim
        for size in hidden_sizes:
            layers.append(nn.Linear(last_dim, size))
            layers.append(nn.ReLU())
            last_dim = size
        layers.append(nn.Linear(last_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def act(self, observation, deterministic=False):
        with torch.no_grad():
            # Accept raw numpy observations from gymnasium as well as tensors
            observation = torch.as_tensor(observation, dtype=torch.float32)
            logits = self.forward(observation)
            if deterministic:
                action = logits.argmax(dim=-1)
            else:
                action_dist = torch.distributions.Categorical(logits=logits)
                action = action_dist.sample()
        return action
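The policy maps an observation to action logits and either samples from a `Categorical` distribution or takes the argmax. A quick usage sketch against gymnasium's `CartPole-v1` (illustrative, not part of this commit):

```python
import gymnasium as gym

from fancy_rl.policy import Policy

env = gym.make("CartPole-v1")
policy = Policy(input_dim=env.observation_space.shape[0], output_dim=env.action_space.n)

obs, _ = env.reset(seed=0)
action = policy.act(obs)                      # stochastic: sampled from Categorical(logits)
greedy = policy.act(obs, deterministic=True)  # deterministic: argmax over the logits
print(int(action), int(greedy))
```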

fancy_rl/ppo.py (new file)
@@ -0,0 +1,98 @@
import torch
import gymnasium as gym
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger
from fancy_rl.on_policy import OnPolicy
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE


class PPO(OnPolicy):
    def __init__(
        self,
        policy,
        env_fn,
        loggers=None,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        total_timesteps=1_000_000,
        eval_interval=2048,
        eval_deterministic=True,
        entropy_coef=0.01,
        critic_coef=0.5,
        normalize_advantage=True,
        device=None,
        clip_epsilon=0.2,
        **kwargs
    ):
        if loggers is None:
            loggers = [TerminalLogger(push_interval=1)]
        super().__init__(
            policy=policy,
            env_fn=env_fn,
            loggers=loggers,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            total_timesteps=total_timesteps,
            eval_interval=eval_interval,
            eval_deterministic=eval_deterministic,
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            normalize_advantage=normalize_advantage,
            device=device,
            **kwargs
        )
        self.clip_epsilon = clip_epsilon

        self.adv_module = GAE(
            gamma=self.gamma,
            lmbda=self.gae_lambda,
            value_network=self.policy,
            average_gae=False,
        )

        self.loss_module = ClipPPOLoss(
            actor_network=self.policy,
            critic_network=self.policy,
            clip_epsilon=self.clip_epsilon,
            loss_critic_type='l2',
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
            normalize_advantage=self.normalize_advantage,
        )

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)

    def train_step(self, batch):
        self.optimizer.zero_grad()
        # ClipPPOLoss returns a TensorDict of loss terms; sum them for the update
        loss_td = self.loss_module(batch)
        loss = loss_td["loss_objective"] + loss_td["loss_critic"] + loss_td["loss_entropy"]
        loss.backward()
        self.optimizer.step()
        return loss

fancy_rl/utils.py (new file)
@@ -0,0 +1,4 @@
import gymnasium as gym


def make_env(env_name):
    return lambda: gym.make(env_name)

setup.py (new file)
@@ -0,0 +1,19 @@
from setuptools import setup, find_packages

setup(
    name="fancy_rl",
    version="0.1",
    packages=find_packages(),
    install_requires=[
        "torch",
        "torchrl",
        "gymnasium",
        "pyyaml",
    ],
    entry_points={
        "console_scripts": [
            "fancy_rl=fancy_rl.example:main",
        ],
    },
)

test/test_ppo.py (new file)
@@ -0,0 +1,54 @@
import pytest
import torch
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger
from fancy_rl.utils import make_env


@pytest.fixture
def policy():
    return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64])


@pytest.fixture
def loggers():
    return [TerminalLogger()]


@pytest.fixture
def env_fn():
    return make_env("CartPole-v1")


def test_ppo_train(policy, loggers, env_fn):
    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              learning_rate=3e-4,
              n_steps=2048,
              batch_size=64,
              n_epochs=10,
              gamma=0.99,
              gae_lambda=0.95,
              clip_epsilon=0.2,
              total_timesteps=10000,
              eval_interval=2048,
              eval_deterministic=True,
              eval_episodes=5,
              seed=42)
    ppo.train()


def test_ppo_evaluate(policy, loggers, env_fn):
    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              learning_rate=3e-4,
              n_steps=2048,
              batch_size=64,
              n_epochs=10,
              gamma=0.99,
              gae_lambda=0.95,
              clip_epsilon=0.2,
              total_timesteps=10000,
              eval_interval=2048,
              eval_deterministic=True,
              eval_episodes=5,
              seed=42)
    ppo.evaluate(epoch=0)