Oh, I could start using git...
commit 8946362336
.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
__pycache__
.venv
wandb
*.egg-info/
README.md (new file, 53 lines)
@@ -0,0 +1,53 @@

# Fancy RL

Fancy RL is a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Projection Layers (TRPL) built on primitives from [torchrl](https://pypi.org/project/torchrl/). Future plans include implementing Soft Actor-Critic (SAC). The library focuses on clean, understandable code while leveraging the functionality of torchrl.

We provide optional integration with wandb (see the logging example under Usage below).

## Installation

Fancy RL requires Python 3.7–3.11 (TorchRL does not yet support Python 3.12).

```bash
pip install -e .
```
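
If you want the optional wandb logging, install wandb as well (it is not included in `setup.py`'s `install_requires`):

```bash
pip install wandb
```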

## Usage

Here's a basic example of how to train a PPO agent with Fancy RL:

```python
import gymnasium as gym

from fancy_rl.policy import Policy
from fancy_rl.ppo import PPO


def env_fn():
    return gym.make("CartPole-v1")


# Create the policy; Policy expects the observation and action dimensions
# (CartPole-v1 has a 4-dimensional observation and 2 discrete actions).
env = env_fn()
policy = Policy(
    input_dim=env.observation_space.shape[0],
    output_dim=env.action_space.n,
)

# Create a PPO instance with the default config
ppo = PPO(policy=policy, env_fn=env_fn)

# Train the agent
ppo.train()
```

For a more complete function description and advanced usage, refer to `example/example.py`.
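
You can attach several loggers when constructing the trainer. A minimal sketch of wiring up both the terminal and wandb loggers, reusing `policy` and `env_fn` from the example above (`my_project` / `my_entity` are placeholders for your own wandb workspace):

```python
from fancy_rl.loggers import TerminalLogger, WandbLogger

loggers = [
    TerminalLogger(push_interval=1),
    # Placeholders: point these at your own wandb project and entity.
    WandbLogger(project="my_project", entity="my_entity", config={}, push_interval=10),
]

ppo = PPO(policy=policy, env_fn=env_fn, loggers=loggers)
ppo.train()
```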

### Testing

To run the test suite:

```bash
pytest test/test_ppo.py
```

## Contributing

Contributions are welcome! Feel free to open issues or submit pull requests to enhance the library.

## License

This project is licensed under the MIT License.
example/config.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
policy:
  input_dim: 4
  output_dim: 2
  hidden_sizes: [64, 64]

ppo:
  learning_rate: 3e-4
  n_steps: 2048
  batch_size: 64
  n_epochs: 10
  gamma: 0.99
  gae_lambda: 0.95
  clip_range: 0.2
  total_timesteps: 1000000
  eval_interval: 2048
  eval_deterministic: true
  eval_episodes: 10
  seed: 42

loggers:
  - type: terminal
  - type: wandb
    project: "PPO_project"
    entity: "your_entity"
    push_interval: 10
example/example.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import yaml
import gymnasium as gym

from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger, WandbLogger


def main(config_file):
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)

    env_fn = lambda: gym.make("CartPole-v1")

    # Build the policy from the dimensions given in the config file.
    policy_config = config['policy']
    policy = Policy(input_dim=policy_config['input_dim'],
                    output_dim=policy_config['output_dim'],
                    hidden_sizes=policy_config['hidden_sizes'])

    ppo_config = config['ppo']
    loggers_config = config['loggers']

    loggers = []
    for logger_config in loggers_config:
        logger_type = logger_config.pop('type')
        if logger_type == 'terminal':
            loggers.append(TerminalLogger(**logger_config))
        elif logger_type == 'wandb':
            loggers.append(WandbLogger(**logger_config))

    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              **ppo_config)

    ppo.train()


if __name__ == "__main__":
    main("example/config.yaml")
fancy_rl/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger, WandbLogger
from fancy_rl.utils import make_env

__all__ = ["PPO", "Policy", "TerminalLogger", "WandbLogger", "make_env"]
fancy_rl/loggers.py (new file, 36 lines)
@@ -0,0 +1,36 @@
class Logger:
    def __init__(self, push_interval=1):
        self.data = {}
        self.push_interval = push_interval

    def log(self, key, value, epoch):
        if key not in self.data:
            self.data[key] = []
        self.data[key].append((epoch, value))

    def end_of_epoch(self, epoch):
        if epoch % self.push_interval == 0:
            self.push()

    def push(self):
        raise NotImplementedError("Push method should be implemented by subclasses")


class TerminalLogger(Logger):
    def push(self):
        for key, values in self.data.items():
            for epoch, value in values:
                print(f"Epoch {epoch}: {key} = {value}")
        self.data = {}


class WandbLogger(Logger):
    # config defaults to None so the logger can be built from a config file
    # that does not provide a wandb config dict (see example/config.yaml).
    def __init__(self, project, entity, config=None, push_interval=1):
        super().__init__(push_interval)
        import wandb
        self.wandb = wandb
        self.wandb.init(project=project, entity=entity, config=config)

    def push(self):
        for key, values in self.data.items():
            for epoch, value in values:
                self.wandb.log({key: value, 'epoch': epoch})
        self.data = {}
fancy_rl/on_policy.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import torch
from abc import ABC, abstractmethod


class OnPolicy(ABC):
    def __init__(
        self,
        policy,
        env_fn,
        loggers,
        learning_rate,
        n_steps,
        batch_size,
        n_epochs,
        gamma,
        gae_lambda,
        total_timesteps,
        eval_interval,
        eval_deterministic,
        entropy_coef,
        critic_coef,
        normalize_advantage,
        device=None,
        **kwargs
    ):
        self.policy = policy
        self.env_fn = env_fn
        self.loggers = loggers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        self.kwargs = kwargs
        self.clip_range = 0.2

    def train(self):
        self.env = self.env_fn()
        # Gymnasium's reset() returns (observation, info).
        state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
        for t in range(int(self.total_timesteps)):
            rollout = self.collect_rollouts(state)
            for batch in self.get_batches(rollout):
                loss = self.train_step(batch)
                for logger in self.loggers:
                    logger.log("loss", loss.item(), epoch=t)

            if (t + 1) % self.eval_interval == 0:
                self.evaluate(t)

    def evaluate(self, epoch):
        eval_env = self.env_fn()
        returns = []
        for _ in range(self.kwargs.get("eval_episodes", 10)):
            state, _ = eval_env.reset(seed=self.kwargs.get("seed", None))
            done = False
            total_return = 0
            while not done:
                action = self.policy.act(state, deterministic=self.eval_deterministic)
                # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
                state, reward, terminated, truncated, _ = eval_env.step(action)
                done = terminated or truncated
                total_return += reward
            returns.append(total_return)

        avg_return = sum(returns) / len(returns)
        for logger in self.loggers:
            logger.log("eval_avg_return", avg_return, epoch=epoch)

    def collect_rollouts(self, state):
        # Roll out the current policy for n_steps, resetting whenever an episode ends.
        rollouts = []
        for _ in range(self.n_steps):
            action = self.policy.act(state)
            next_state, reward, terminated, truncated, _ = self.env.step(action)
            done = terminated or truncated
            rollouts.append((state, action, reward, next_state, done))
            state = next_state
            if done:
                state, _ = self.env.reset(seed=self.kwargs.get("seed", None))
        return rollouts

    def get_batches(self, rollouts):
        data = self.prepare_data(rollouts)
        n_samples = data["obs"].shape[0]
        n_batches = n_samples // self.batch_size
        batches = []
        for _ in range(n_batches):
            batch_indices = torch.randint(0, n_samples, (self.batch_size,))
            batch = {key: value[batch_indices] for key, value in data.items()}
            batches.append(batch)
        return batches

    def prepare_data(self, rollouts):
        obs, actions, rewards, next_obs, dones = zip(*rollouts)
        obs = torch.tensor(obs, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_obs = torch.tensor(next_obs, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        data = {
            "obs": obs,
            "actions": actions,
            "rewards": rewards,
            "next_obs": next_obs,
            "dones": dones
        }
        # The subclass-provided advantage module (e.g. GAE in PPO) augments the
        # collected data with advantage estimates before batching.
        data = self.adv_module(data)
        return data

    @abstractmethod
    def train_step(self, batch):
        pass
fancy_rl/policy.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import torch
from torch import nn


class Policy(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_sizes=(64, 64)):
        super().__init__()
        layers = []
        last_dim = input_dim
        for size in hidden_sizes:
            layers.append(nn.Linear(last_dim, size))
            layers.append(nn.ReLU())
            last_dim = size
        layers.append(nn.Linear(last_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def act(self, observation, deterministic=False):
        with torch.no_grad():
            # Accept raw numpy observations coming straight from the environment.
            obs = torch.as_tensor(observation, dtype=torch.float32)
            logits = self.forward(obs)
            if deterministic:
                action = logits.argmax(dim=-1)
            else:
                action_dist = torch.distributions.Categorical(logits=logits)
                action = action_dist.sample()
        # Return a plain Python int so the action can be passed to env.step().
        return action.item()
fancy_rl/ppo.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import torch

from fancy_rl.loggers import TerminalLogger
from fancy_rl.on_policy import OnPolicy
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE


class PPO(OnPolicy):
    def __init__(
        self,
        policy,
        env_fn,
        loggers=None,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        total_timesteps=1_000_000,
        eval_interval=2048,
        eval_deterministic=True,
        entropy_coef=0.01,
        critic_coef=0.5,
        normalize_advantage=True,
        device=None,
        clip_epsilon=0.2,
        **kwargs
    ):
        if loggers is None:
            loggers = [TerminalLogger(push_interval=1)]

        super().__init__(
            policy=policy,
            env_fn=env_fn,
            loggers=loggers,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            total_timesteps=total_timesteps,
            eval_interval=eval_interval,
            eval_deterministic=eval_deterministic,
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            normalize_advantage=normalize_advantage,
            device=device,
            **kwargs
        )

        self.clip_epsilon = clip_epsilon
        self.adv_module = GAE(
            gamma=self.gamma,
            lmbda=self.gae_lambda,
            value_network=self.policy,
            average_gae=False,
        )

        self.loss_module = ClipPPOLoss(
            actor_network=self.policy,
            critic_network=self.policy,
            clip_epsilon=self.clip_epsilon,
            loss_critic_type="l2",  # torchrl expects "l1", "l2" or "smooth_l1"
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
            normalize_advantage=self.normalize_advantage,
        )

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=self.learning_rate)

    def train_step(self, batch):
        self.optimizer.zero_grad()
        # ClipPPOLoss returns a tensordict of loss terms; sum them into one scalar.
        loss_td = self.loss_module(batch)
        loss = sum(value for key, value in loss_td.items() if key.startswith("loss_"))
        loss.backward()
        self.optimizer.step()
        return loss
fancy_rl/utils.py (new file, 4 lines)
@@ -0,0 +1,4 @@
import gymnasium as gym


def make_env(env_name):
    return lambda: gym.make(env_name)
setup.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from setuptools import setup, find_packages

setup(
    name="fancy_rl",
    version="0.1",
    packages=find_packages(),
    install_requires=[
        "torch",
        "torchrl",
        "gymnasium",
        "pyyaml",
    ],
    entry_points={
        "console_scripts": [
            "fancy_rl=fancy_rl.example:main",
        ],
    },
)
test/test_ppo.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import pytest
import torch
from fancy_rl.ppo import PPO
from fancy_rl.policy import Policy
from fancy_rl.loggers import TerminalLogger
from fancy_rl.utils import make_env


@pytest.fixture
def policy():
    return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64])


@pytest.fixture
def loggers():
    return [TerminalLogger()]


@pytest.fixture
def env_fn():
    return make_env("CartPole-v1")


def test_ppo_train(policy, loggers, env_fn):
    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              learning_rate=3e-4,
              n_steps=2048,
              batch_size=64,
              n_epochs=10,
              gamma=0.99,
              gae_lambda=0.95,
              clip_range=0.2,
              total_timesteps=10000,
              eval_interval=2048,
              eval_deterministic=True,
              eval_episodes=5,
              seed=42)
    ppo.train()


def test_ppo_evaluate(policy, loggers, env_fn):
    ppo = PPO(policy=policy,
              env_fn=env_fn,
              loggers=loggers,
              learning_rate=3e-4,
              n_steps=2048,
              batch_size=64,
              n_epochs=10,
              gamma=0.99,
              gae_lambda=0.95,
              clip_range=0.2,
              total_timesteps=10000,
              eval_interval=2048,
              eval_deterministic=True,
              eval_episodes=5,
              seed=42)
    ppo.evaluate(epoch=0)