From 8946362336b2cdb4bf2eea4df9302ff4c9b763f4 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Wed, 29 May 2024 21:21:43 +0200
Subject: [PATCH] Oh, I could start using git...

---
 .gitignore            |   4 ++
 README.md             |  53 +++++++++++++++
 example/config.yaml   |  25 +++++++
 example/example.py    |  37 ++++++++++
 fancy_rl/__init__.py  |   6 ++
 fancy_rl/loggers.py   |  36 ++++++++++
 fancy_rl/on_policy.py | 131 ++++++++++++++++++++++++++++++++++
 fancy_rl/policy.py    |  27 +++++++
 fancy_rl/ppo.py       |  98 +++++++++++++++++++++++
 fancy_rl/utils.py     |   4 ++
 setup.py              |  19 +++++
 test/test_ppo.py      |  54 +++++++++++++++
 12 files changed, 494 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 example/config.yaml
 create mode 100644 example/example.py
 create mode 100644 fancy_rl/__init__.py
 create mode 100644 fancy_rl/loggers.py
 create mode 100644 fancy_rl/on_policy.py
 create mode 100644 fancy_rl/policy.py
 create mode 100644 fancy_rl/ppo.py
 create mode 100644 fancy_rl/utils.py
 create mode 100644 setup.py
 create mode 100644 test/test_ppo.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6e18e7b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__
+.venv
+wandb
+*.egg-info/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..68a1a2b
--- /dev/null
+++ b/README.md
@@ -0,0 +1,53 @@
+# Fancy RL
+
+Fancy RL is a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Policy Layers (TRPL) built on primitives from [torchrl](https://pypi.org/project/torchrl/). Soft Actor-Critic (SAC) is planned. The library focuses on clean, understandable code while leveraging the functionality of torchrl. Optional logging to wandb is provided.
+
+## Installation
+
+Fancy RL requires Python 3.7-3.11 (TorchRL does not yet support Python 3.12).
+
+```bash
+pip install -e .
+```
+
+## Usage
+
+Here's a basic example of how to train a PPO agent with Fancy RL:
+
+```python
+from fancy_rl.ppo import PPO
+from fancy_rl.policy import Policy
+import gymnasium as gym
+
+def env_fn():
+    return gym.make("CartPole-v1")
+
+# Create the policy from the environment's observation and action dimensions
+env = env_fn()
+policy = Policy(
+    input_dim=env.observation_space.shape[0],
+    output_dim=env.action_space.n,
+)
+
+# Create a PPO instance with the default configuration
+ppo = PPO(policy=policy, env_fn=env_fn)
+
+# Train the agent
+ppo.train()
+```
+
+For a configuration-driven example and more advanced usage, see `example/example.py`.
+
+### Testing
+
+To run the test suite:
+
+```bash
+pytest test/test_ppo.py
+```
+
+## Contributing
+
+Contributions are welcome! Feel free to open issues or submit pull requests to enhance the library.
+
+## License
+
+This project is licensed under the MIT License.
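+
+## Logging
+
+Logging goes to the terminal by default. To also push metrics to wandb, pass a `WandbLogger` alongside the `TerminalLogger`. A minimal sketch, reusing `policy` and `env_fn` from the usage example above; the project and entity names are placeholders for your own wandb setup, and `config` is forwarded to `wandb.init`:
+
+```python
+from fancy_rl.loggers import TerminalLogger, WandbLogger
+
+loggers = [
+    TerminalLogger(push_interval=1),
+    WandbLogger(project="my_project", entity="my_entity", config={}, push_interval=10),
+]
+ppo = PPO(policy=policy, env_fn=env_fn, loggers=loggers)
+ppo.train()
+```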
diff --git a/example/config.yaml b/example/config.yaml
new file mode 100644
index 0000000..89ca90a
--- /dev/null
+++ b/example/config.yaml
@@ -0,0 +1,25 @@
+policy:
+  input_dim: 4
+  output_dim: 2
+  hidden_sizes: [64, 64]
+
+ppo:
+  learning_rate: 3e-4
+  n_steps: 2048
+  batch_size: 64
+  n_epochs: 10
+  gamma: 0.99
+  gae_lambda: 0.95
+  clip_epsilon: 0.2
+  total_timesteps: 1000000
+  eval_interval: 2048
+  eval_deterministic: true
+  eval_episodes: 10
+  seed: 42
+
+loggers:
+  - type: terminal
+  - type: wandb
+    project: "PPO_project"
+    entity: "your_entity"
+    push_interval: 10
diff --git a/example/example.py b/example/example.py
new file mode 100644
index 0000000..5baf4ba
--- /dev/null
+++ b/example/example.py
@@ -0,0 +1,37 @@
+import yaml
+import gymnasium as gym
+from fancy_rl.ppo import PPO
+from fancy_rl.policy import Policy
+from fancy_rl.loggers import TerminalLogger, WandbLogger
+
+def main(config_file):
+    with open(config_file, 'r') as file:
+        config = yaml.safe_load(file)
+
+    env_fn = lambda: gym.make("CartPole-v1")
+
+    policy_config = config['policy']
+    policy = Policy(
+        input_dim=policy_config['input_dim'],
+        output_dim=policy_config['output_dim'],
+        hidden_sizes=policy_config['hidden_sizes'],
+    )
+
+    ppo_config = config['ppo']
+    loggers_config = config['loggers']
+
+    loggers = []
+    for logger_config in loggers_config:
+        logger_type = logger_config.pop('type')
+        if logger_type == 'terminal':
+            loggers.append(TerminalLogger(**logger_config))
+        elif logger_type == 'wandb':
+            loggers.append(WandbLogger(**logger_config))
+
+    ppo = PPO(policy=policy,
+              env_fn=env_fn,
+              loggers=loggers,
+              **ppo_config)
+
+    ppo.train()
+
+if __name__ == "__main__":
+    main("example/config.yaml")
diff --git a/fancy_rl/__init__.py b/fancy_rl/__init__.py
new file mode 100644
index 0000000..efd1a8b
--- /dev/null
+++ b/fancy_rl/__init__.py
@@ -0,0 +1,6 @@
+from fancy_rl.ppo import PPO
+from fancy_rl.policy import Policy
+from fancy_rl.loggers import TerminalLogger, WandbLogger
+from fancy_rl.utils import make_env
+
+__all__ = ["PPO", "Policy", "TerminalLogger", "WandbLogger", "make_env"]
diff --git a/fancy_rl/loggers.py b/fancy_rl/loggers.py
new file mode 100644
index 0000000..db3e0a9
--- /dev/null
+++ b/fancy_rl/loggers.py
@@ -0,0 +1,36 @@
+class Logger:
+    def __init__(self, push_interval=1):
+        self.data = {}
+        self.push_interval = push_interval
+
+    def log(self, metrics, epoch):
+        # metrics: dict mapping metric names to values for this epoch
+        for key, value in metrics.items():
+            if key not in self.data:
+                self.data[key] = []
+            self.data[key].append((epoch, value))
+
+    def end_of_epoch(self, epoch):
+        if epoch % self.push_interval == 0:
+            self.push()
+
+    def push(self):
+        raise NotImplementedError("push() should be implemented by subclasses")
+
+class TerminalLogger(Logger):
+    def push(self):
+        for key, values in self.data.items():
+            for epoch, value in values:
+                print(f"Epoch {epoch}: {key} = {value}")
+        self.data = {}
+
+class WandbLogger(Logger):
+    def __init__(self, project, entity, config=None, push_interval=1):
+        super().__init__(push_interval)
+        import wandb
+        self.wandb = wandb
+        self.wandb.init(project=project, entity=entity, config=config)
+
+    def push(self):
+        for key, values in self.data.items():
+            for epoch, value in values:
+                self.wandb.log({key: value, 'epoch': epoch})
+        self.data = {}
diff --git a/fancy_rl/on_policy.py b/fancy_rl/on_policy.py
new file mode 100644
index 0000000..7152948
--- /dev/null
+++ b/fancy_rl/on_policy.py
@@ -0,0 +1,131 @@
+import numpy as np
+import torch
+from abc import ABC, abstractmethod
+
+class OnPolicy(ABC):
+    def __init__(
+        self,
+        policy,
+        env_fn,
+        loggers,
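+        # The remaining constructor arguments are plain scalar hyperparameters;
+        # PPO supplies defaults for all of them in fancy_rl/ppo.py.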
+        learning_rate,
+        n_steps,
+        batch_size,
+        n_epochs,
+        gamma,
+        gae_lambda,
+        total_timesteps,
+        eval_interval,
+        eval_deterministic,
+        entropy_coef,
+        critic_coef,
+        normalize_advantage,
+        device=None,
+        **kwargs
+    ):
+        self.policy = policy
+        self.env_fn = env_fn
+        self.loggers = loggers
+        self.learning_rate = learning_rate
+        self.n_steps = n_steps
+        self.batch_size = batch_size
+        self.n_epochs = n_epochs
+        self.gamma = gamma
+        self.gae_lambda = gae_lambda
+        self.total_timesteps = total_timesteps
+        self.eval_interval = eval_interval
+        self.eval_deterministic = eval_deterministic
+        self.entropy_coef = entropy_coef
+        self.critic_coef = critic_coef
+        self.normalize_advantage = normalize_advantage
+        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.kwargs = kwargs
+
+    def train(self):
+        self.env = self.env_fn()
+        state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
+
+        # One training iteration consumes one rollout of n_steps environment steps.
+        num_iterations = int(self.total_timesteps) // self.n_steps
+        for t in range(num_iterations):
+            rollout, state = self.collect_rollouts(state)
+            for _ in range(self.n_epochs):
+                for batch in self.get_batches(rollout):
+                    loss = self.train_step(batch)
+                    for logger in self.loggers:
+                        logger.log({"loss": loss.item()}, epoch=t)
+            for logger in self.loggers:
+                logger.end_of_epoch(t)
+
+            if ((t + 1) * self.n_steps) % self.eval_interval == 0:
+                self.evaluate(t)
+
+    def evaluate(self, epoch):
+        eval_env = self.env_fn()
+        returns = []
+        for _ in range(self.kwargs.get("eval_episodes", 10)):
+            state, _ = eval_env.reset(seed=self.kwargs.get("seed", None))
+            done = False
+            total_return = 0.0
+            while not done:
+                with torch.no_grad():
+                    action = self.policy.act(state, deterministic=self.eval_deterministic)
+                state, reward, terminated, truncated, _ = eval_env.step(action)
+                done = terminated or truncated
+                total_return += reward
+            returns.append(total_return)
+
+        avg_return = sum(returns) / len(returns)
+        for logger in self.loggers:
+            logger.log({"eval_avg_return": avg_return}, epoch=epoch)
+
+    def collect_rollouts(self, state):
+        # Collect one on-policy rollout of n_steps transitions, resetting the
+        # environment whenever an episode terminates or is truncated.
+        rollouts = []
+        for _ in range(self.n_steps):
+            action = self.policy.act(state)
+            next_state, reward, terminated, truncated, _ = self.env.step(action)
+            done = terminated or truncated
+            rollouts.append((state, action, reward, next_state, done))
+            state = next_state
+            if done:
+                state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
+        return rollouts, state
+
+    def get_batches(self, rollouts):
+        data = self.prepare_data(rollouts)
+        n_samples = data["obs"].shape[0]
+        n_batches = n_samples // self.batch_size
+        batches = []
+        for _ in range(n_batches):
+            batch_indices = torch.randint(0, n_samples, (self.batch_size,))
+            batches.append({key: value[batch_indices] for key, value in data.items()})
+        return batches
+
+    def prepare_data(self, rollouts):
+        obs, actions, rewards, next_obs, dones = zip(*rollouts)
+        data = {
+            "obs": torch.as_tensor(np.array(obs), dtype=torch.float32),
+            "actions": torch.as_tensor(np.array(actions), dtype=torch.int64),
+            "rewards": torch.as_tensor(np.array(rewards), dtype=torch.float32),
+            "next_obs": torch.as_tensor(np.array(next_obs), dtype=torch.float32),
+            "dones": torch.as_tensor(np.array(dones), dtype=torch.float32),
+        }
+        # Advantage estimation is delegated to the adv_module provided by the
+        # subclass (torchrl's GAE in the PPO implementation).
+        data = self.adv_module(data)
+        return data
+
+    @abstractmethod
+    def train_step(self, batch):
+        pass
diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py
new file mode 100644
index 0000000..40c6a48
--- /dev/null
+++ b/fancy_rl/policy.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+
+# Simple MLP policy for discrete action spaces (categorical over the output logits).
+class Policy(nn.Module):
+    def __init__(self, input_dim, output_dim,
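+                 # hidden_sizes: one entry per hidden layer of the MLP
+                 # (the default mirrors example/config.yaml: two 64-unit layers)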
+                 hidden_sizes=[64, 64]):
+        super().__init__()
+        layers = []
+        last_dim = input_dim
+        for size in hidden_sizes:
+            layers.append(nn.Linear(last_dim, size))
+            layers.append(nn.ReLU())
+            last_dim = size
+        layers.append(nn.Linear(last_dim, output_dim))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def act(self, observation, deterministic=False):
+        # Expects a single (unbatched) observation from the environment and
+        # returns a scalar action suitable for env.step().
+        observation = torch.as_tensor(observation, dtype=torch.float32)
+        with torch.no_grad():
+            logits = self.forward(observation)
+            if deterministic:
+                action = logits.argmax(dim=-1)
+            else:
+                action_dist = torch.distributions.Categorical(logits=logits)
+                action = action_dist.sample()
+        return action.item()
diff --git a/fancy_rl/ppo.py b/fancy_rl/ppo.py
new file mode 100644
index 0000000..7b4abe2
--- /dev/null
+++ b/fancy_rl/ppo.py
@@ -0,0 +1,98 @@
+import torch
+from fancy_rl.loggers import TerminalLogger
+from fancy_rl.on_policy import OnPolicy
+from torchrl.objectives import ClipPPOLoss
+from torchrl.objectives.value.advantages import GAE
+
+class PPO(OnPolicy):
+    def __init__(
+        self,
+        policy,
+        env_fn,
+        loggers=None,
+        learning_rate=3e-4,
+        n_steps=2048,
+        batch_size=64,
+        n_epochs=10,
+        gamma=0.99,
+        gae_lambda=0.95,
+        total_timesteps=1_000_000,
+        eval_interval=2048,
+        eval_deterministic=True,
+        entropy_coef=0.01,
+        critic_coef=0.5,
+        normalize_advantage=True,
+        device=None,
+        clip_epsilon=0.2,
+        **kwargs
+    ):
+        if loggers is None:
+            loggers = [TerminalLogger(push_interval=1)]
+
+        super().__init__(
+            policy=policy,
+            env_fn=env_fn,
+            loggers=loggers,
+            learning_rate=learning_rate,
+            n_steps=n_steps,
+            batch_size=batch_size,
+            n_epochs=n_epochs,
+            gamma=gamma,
+            gae_lambda=gae_lambda,
+            total_timesteps=total_timesteps,
+            eval_interval=eval_interval,
+            eval_deterministic=eval_deterministic,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            normalize_advantage=normalize_advantage,
+            device=device,
+            **kwargs
+        )
+
+        self.clip_epsilon = clip_epsilon
+
+        # torchrl's GAE and ClipPPOLoss expect tensordict-compatible actor and
+        # value modules; here the same Policy network is wired into both roles.
+        self.adv_module = GAE(
+            gamma=self.gamma,
+            lmbda=self.gae_lambda,
+            value_network=self.policy,
+            average_gae=False,
+        )
+
+        self.loss_module = ClipPPOLoss(
+            actor_network=self.policy,
+            critic_network=self.policy,
+            clip_epsilon=self.clip_epsilon,
+            loss_critic_type='l2',
+            entropy_coef=self.entropy_coef,
+            critic_coef=self.critic_coef,
+            normalize_advantage=self.normalize_advantage,
+        )
+
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)
+
+    # The training and evaluation loops are inherited from OnPolicy.
+
+    def train_step(self, batch):
+        self.optimizer.zero_grad()
+        # ClipPPOLoss returns a tensordict of loss terms; sum them into one scalar.
+        loss_td = self.loss_module(batch)
+        loss = loss_td["loss_objective"] + loss_td["loss_critic"] + loss_td["loss_entropy"]
+        loss.backward()
+        self.optimizer.step()
+        return loss
diff --git a/fancy_rl/utils.py b/fancy_rl/utils.py
new file mode 100644
index 0000000..b08c123
--- /dev/null
+++ b/fancy_rl/utils.py
@@ -0,0 +1,4 @@
+import gymnasium as gym
+
+def make_env(env_name):
+    return lambda: gym.make(env_name)
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..0dfab7a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,19 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="fancy_rl",
+    version="0.1",
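+    # find_packages() picks up only the fancy_rl package itself; the example/
+    # and test/ directories are not installed.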
+    packages=find_packages(),
+    python_requires=">=3.7,<3.12",
+    install_requires=[
+        "torch",
+        "torchrl",
+        "gymnasium",
+        "pyyaml",
+    ],
+    entry_points={
+        "console_scripts": [
+            # NOTE: this assumes the example is importable as fancy_rl.example;
+            # the script currently lives in example/example.py.
+            "fancy_rl=fancy_rl.example:main",
+        ],
+    },
+)
diff --git a/test/test_ppo.py b/test/test_ppo.py
new file mode 100644
index 0000000..29c6c90
--- /dev/null
+++ b/test/test_ppo.py
@@ -0,0 +1,54 @@
+import pytest
+from fancy_rl.ppo import PPO
+from fancy_rl.policy import Policy
+from fancy_rl.loggers import TerminalLogger
+from fancy_rl.utils import make_env
+
+@pytest.fixture
+def policy():
+    return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64])
+
+@pytest.fixture
+def loggers():
+    return [TerminalLogger()]
+
+@pytest.fixture
+def env_fn():
+    return make_env("CartPole-v1")
+
+def test_ppo_train(policy, loggers, env_fn):
+    ppo = PPO(policy=policy,
+              env_fn=env_fn,
+              loggers=loggers,
+              learning_rate=3e-4,
+              n_steps=2048,
+              batch_size=64,
+              n_epochs=10,
+              gamma=0.99,
+              gae_lambda=0.95,
+              clip_epsilon=0.2,
+              total_timesteps=10000,
+              eval_interval=2048,
+              eval_deterministic=True,
+              eval_episodes=5,
+              seed=42)
+    ppo.train()
+
+def test_ppo_evaluate(policy, loggers, env_fn):
+    ppo = PPO(policy=policy,
+              env_fn=env_fn,
+              loggers=loggers,
+              learning_rate=3e-4,
+              n_steps=2048,
+              batch_size=64,
+              n_epochs=10,
+              gamma=0.99,
+              gae_lambda=0.95,
+              clip_epsilon=0.2,
+              total_timesteps=10000,
+              eval_interval=2048,
+              eval_deterministic=True,
+              eval_episodes=5,
+              seed=42)
+    ppo.evaluate(epoch=0)
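+    # The call above is a smoke test: evaluate() only logs the average return,
+    # so completing without an exception is the success criterion here.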