Compare commits


2 Commits

| SHA1 | Message | Date |
| --- | --- | --- |
| 4240f611ac | Updated TODO | 2024-06-02 16:37:15 +02:00 |
| 8d5d44e992 | Refactor out some func into general Algo class | 2024-06-02 16:36:59 +02:00 |
3 changed files with 121 additions and 56 deletions

README.md

@@ -9,7 +9,7 @@
Fancy RL provides a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Projection Layers (TRPL) using primitives from [torchrl](https://pypi.org/project/torchrl/). This library focuses on providing clean, understandable code and reusable modules while leveraging the powerful functionality of torchrl.
| :exclamation: This project is still WIP and not ready to be used. |
-| ------------------------------------------------------------ |
+| ----------------------------------------------------------------- |
## Installation
@@ -54,6 +54,7 @@ pytest test/test_ppo.py
## TODO
- [ ] Better / more logging
- [ ] Test / Benchmark PPO
- [ ] Refactor Modules for TRPL
- [ ] Get TRPL working
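
The README describes a training API built on torchrl. As a purely hypothetical sketch of what end-to-end usage could look like once the project stabilizes (the `PPO` import path and constructor call are assumptions inferred from `test/test_ppo.py` and `Algo.make_env`, not confirmed by this diff):

```python
# Hypothetical usage sketch -- the import path and arguments are assumptions;
# the README explicitly marks the project as WIP.
from fancy_rl import PPO  # assumed export, suggested by test/test_ppo.py

algo = PPO(env_spec="CartPole-v1")  # make_env accepts a Gymnasium id string
algo.train()
```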

fancy_rl/algos/algo.py (new file, 75 lines)

@@ -0,0 +1,75 @@
import torch
import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
from fancy_rl.loggers import TerminalLogger


class Algo(ABC):
    def __init__(
        self,
        env_spec,
        loggers,
        optimizers,
        learning_rate,
        n_steps,
        batch_size,
        n_epochs,
        gamma,
        total_timesteps,
        eval_interval,
        eval_deterministic,
        entropy_coef,
        critic_coef,
        normalize_advantage,
        device=None,
        eval_episodes=10,
        env_spec_eval=None,
    ):
        self.env_spec = env_spec
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
        self.loggers = loggers if loggers is not None else [TerminalLogger(None, None)]
        self.optimizers = optimizers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.eval_episodes = eval_episodes

    def make_env(self, eval=False):
        """Creates an environment and wraps it if necessary."""
        env_spec = self.env_spec_eval if eval else self.env_spec
        if isinstance(env_spec, str):
            env = gym.make(env_spec)
            env = GymWrapper(env).to(self.device)
        elif callable(env_spec):
            env = env_spec()
            if isinstance(env, gym.Env):
                env = GymWrapper(env).to(self.device)
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec).to(self.device)
        else:
            raise ValueError(
                "env_spec must be a string, a gym.Env, or a callable that returns an environment."
            )
        return env

    def train_step(self, batch):
        raise NotImplementedError("train_step method must be implemented in subclass.")

    def train(self):
        raise NotImplementedError("train method must be implemented in subclass.")

    def evaluate(self, epoch):
        raise NotImplementedError("evaluate method must be implemented in subclass.")


def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()
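
`Algo` stores shared configuration, provides `make_env`, and leaves `train_step`, `train`, and `evaluate` abstract. A minimal sketch of the subclass contract this implies (the `MyAlgo` name and the commented-out logic are illustrative, not part of this commit):

```python
# Illustrative sketch of the Algo subclass contract; MyAlgo is a placeholder.
from fancy_rl.algos.algo import Algo


class MyAlgo(Algo):
    def train(self):
        env = self.make_env()  # string spec -> gym.make(...) wrapped via GymWrapper
        # collect rollouts from env, then optimize with self.train_step(batch)

    def train_step(self, batch):
        # compute losses and step the optimizers in self.optimizers
        pass

    def evaluate(self, epoch):
        eval_env = self.make_env(eval=True)  # uses env_spec_eval if one was given
        # run self.eval_episodes episodes and report via self.loggers
        pass
```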

fancy_rl/algos/on_policy.py

@@ -5,50 +5,58 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
 from tensordict import LazyStackedTensorDict, TensorDict
-from abc import ABC
-from fancy_rl.loggers import TerminalLogger
+from fancy_rl.algos.algo import Algo


-class OnPolicy(ABC):
+class OnPolicy(Algo):
     def __init__(
         self,
         env_spec,
-        loggers,
-        optimizers,
-        learning_rate,
-        n_steps,
-        batch_size,
-        n_epochs,
-        gamma,
-        total_timesteps,
-        eval_interval,
-        eval_deterministic,
-        entropy_coef,
-        critic_coef,
-        normalize_advantage,
-        device=None,
-        eval_episodes=10,
+        loggers=None,
+        optimizers=None,  # forwarded to Algo.__init__ below
+        actor_hidden_sizes=[64, 64],
+        critic_hidden_sizes=[64, 64],
+        actor_activation_fn="Tanh",
+        critic_activation_fn="Tanh",
+        learning_rate=3e-4,
+        n_steps=2048,
+        batch_size=64,
+        n_epochs=10,
+        gamma=0.99,
+        gae_lambda=0.95,
+        total_timesteps=1e6,
+        eval_interval=2048,
+        eval_deterministic=True,
+        entropy_coef=0.01,
+        critic_coef=0.5,
+        normalize_advantage=True,
+        clip_range=0.2,
+        env_spec_eval=None,
+        eval_episodes=10,
+        device=None,
     ):
-        self.env_spec = env_spec
-        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
-        self.optimizers = optimizers
-        self.learning_rate = learning_rate
-        self.n_steps = n_steps
-        self.batch_size = batch_size
-        self.n_epochs = n_epochs
-        self.gamma = gamma
-        self.total_timesteps = total_timesteps
-        self.eval_interval = eval_interval
-        self.eval_deterministic = eval_deterministic
-        self.entropy_coef = entropy_coef
-        self.critic_coef = critic_coef
-        self.normalize_advantage = normalize_advantage
-        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
-        self.eval_episodes = eval_episodes
+        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        super().__init__(
+            env_spec=env_spec,
+            loggers=loggers,
+            optimizers=optimizers,
+            learning_rate=learning_rate,
+            n_steps=n_steps,
+            batch_size=batch_size,
+            n_epochs=n_epochs,
+            gamma=gamma,
+            total_timesteps=total_timesteps,
+            eval_interval=eval_interval,
+            eval_deterministic=eval_deterministic,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            normalize_advantage=normalize_advantage,
+            device=device,
+            env_spec_eval=env_spec_eval,
+            eval_episodes=eval_episodes,
+        )

         # Create collector
         self.collector = SyncDataCollector(
@@ -69,21 +77,6 @@ class OnPolicy(ABC):
             batch_size=self.batch_size,
         )

-    def make_env(self, eval=False):
-        """Creates an environment and wraps it if necessary."""
-        env_spec = self.env_spec_eval if eval else self.env_spec
-        if isinstance(env_spec, str):
-            env = gym.make(env_spec)
-            env = GymWrapper(env).to(self.device)
-        elif callable(env_spec):
-            env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
-        elif isinstance(env, gym.Env):
-            env = GymWrapper(env).to(self.device)
-        else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
-        return env

     def train_step(self, batch):
         for optimizer in self.optimizers.values():
@@ -139,8 +132,4 @@ class OnPolicy(ABC):
         avg_return = torch.cat(test_rewards, 0).mean().item()
         for logger in self.loggers:
             logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)

-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
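
For context on what `OnPolicy.__init__` wires together: torchrl's `SyncDataCollector` yields fixed-size `TensorDict` batches that `train_step` can consume. A standalone sketch of that loop (the environment id and sizes are illustrative, and `policy=None` makes torchrl substitute a random policy for the actor):

```python
# Standalone sketch of the torchrl collection loop OnPolicy builds on;
# the env id and sizes are illustrative, not taken from this diff.
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
collector = SyncDataCollector(
    env,
    policy=None,            # None -> torchrl falls back to a random policy
    frames_per_batch=2048,  # plays the role of n_steps
    total_frames=10_000,    # plays the role of total_timesteps
)
for batch in collector:     # each batch is a TensorDict of 2048 transitions
    # a train_step would split this into minibatches and run n_epochs of updates
    print(batch.batch_size)
```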