Compare commits
2 Commits
5f186af9fb
...
4240f611ac
Author | SHA1 | Date | |
---|---|---|---|
4240f611ac | |||
8d5d44e992 |
@ -9,7 +9,7 @@
|
||||
Fancy RL provides a minimalistic and efficient implementation of Proximal Policy Optimization (PPO) and Trust Region Policy Layers (TRPL) using primitives from [torchrl](https://pypi.org/project/torchrl/). This library focuses on providing clean, understandable code and reusable modules while leveraging the powerful functionalities of torchrl.
|
||||
|
||||
| :exclamation: This project is still WIP and not ready to be used. |
|
||||
| ------------------------------------------------------------ |
|
||||
| ----------------------------------------------------------------- |
|
||||
|
||||
## Installation
|
||||
|
||||
@ -54,6 +54,7 @@ pytest test/test_ppo.py
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Better / more logging
|
||||
- [ ] Test / Benchmark PPO
|
||||
- [ ] Refactor Modules for TRPL
|
||||
- [ ] Get TRPL working
|
||||
|
75
fancy_rl/algos/algo.py
Normal file
75
fancy_rl/algos/algo.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import gymnasium as gym
|
||||
from torchrl.envs.libs.gym import GymWrapper
|
||||
from torchrl.record import VideoRecorder
|
||||
from abc import ABC
|
||||
|
||||
from fancy_rl.loggers import TerminalLogger
|
||||
|
||||
class Algo(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
env_spec,
|
||||
loggers,
|
||||
optimizers,
|
||||
learning_rate,
|
||||
n_steps,
|
||||
batch_size,
|
||||
n_epochs,
|
||||
gamma,
|
||||
total_timesteps,
|
||||
eval_interval,
|
||||
eval_deterministic,
|
||||
entropy_coef,
|
||||
critic_coef,
|
||||
normalize_advantage,
|
||||
device=None,
|
||||
eval_episodes=10,
|
||||
env_spec_eval=None,
|
||||
):
|
||||
self.env_spec = env_spec
|
||||
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
|
||||
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
|
||||
self.optimizers = optimizers
|
||||
self.learning_rate = learning_rate
|
||||
self.n_steps = n_steps
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.gamma = gamma
|
||||
self.total_timesteps = total_timesteps
|
||||
self.eval_interval = eval_interval
|
||||
self.eval_deterministic = eval_deterministic
|
||||
self.entropy_coef = entropy_coef
|
||||
self.critic_coef = critic_coef
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.eval_episodes = eval_episodes
|
||||
|
||||
def make_env(self, eval=False):
|
||||
"""Creates an environment and wraps it if necessary."""
|
||||
env_spec = self.env_spec_eval if eval else self.env_spec
|
||||
if isinstance(env_spec, str):
|
||||
env = gym.make(env_spec)
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif callable(env_spec):
|
||||
env = env_spec()
|
||||
if isinstance(env, gym.Env):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif isinstance(env, gym.Env):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
else:
|
||||
raise ValueError("env_spec must be a string or a callable that returns an environment.")
|
||||
return env
|
||||
|
||||
def train_step(self, batch):
|
||||
raise NotImplementedError("train_step method must be implemented in subclass.")
|
||||
|
||||
def train(self):
|
||||
raise NotImplementedError("train method must be implemented in subclass.")
|
||||
|
||||
def evaluate(self, epoch):
|
||||
raise NotImplementedError("evaluate method must be implemented in subclass.")
|
||||
|
||||
def dump_video(module):
|
||||
if isinstance(module, VideoRecorder):
|
||||
module.dump()
|
@ -5,50 +5,58 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
||||
from torchrl.envs.libs.gym import GymWrapper
|
||||
from torchrl.envs import ExplorationType, set_exploration_type
|
||||
from torchrl.record import VideoRecorder
|
||||
from tensordict import LazyStackedTensorDict, TensorDict
|
||||
from abc import ABC
|
||||
|
||||
from fancy_rl.loggers import TerminalLogger
|
||||
from fancy_rl.algos.algo import Algo
|
||||
|
||||
class OnPolicy(ABC):
|
||||
class OnPolicy(Algo):
|
||||
def __init__(
|
||||
self,
|
||||
env_spec,
|
||||
loggers,
|
||||
optimizers,
|
||||
learning_rate,
|
||||
n_steps,
|
||||
batch_size,
|
||||
n_epochs,
|
||||
gamma,
|
||||
total_timesteps,
|
||||
eval_interval,
|
||||
eval_deterministic,
|
||||
entropy_coef,
|
||||
critic_coef,
|
||||
normalize_advantage,
|
||||
device=None,
|
||||
eval_episodes=10,
|
||||
loggers=None,
|
||||
actor_hidden_sizes=[64, 64],
|
||||
critic_hidden_sizes=[64, 64],
|
||||
actor_activation_fn="Tanh",
|
||||
critic_activation_fn="Tanh",
|
||||
learning_rate=3e-4,
|
||||
n_steps=2048,
|
||||
batch_size=64,
|
||||
n_epochs=10,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
total_timesteps=1e6,
|
||||
eval_interval=2048,
|
||||
eval_deterministic=True,
|
||||
entropy_coef=0.01,
|
||||
critic_coef=0.5,
|
||||
normalize_advantage=True,
|
||||
clip_range=0.2,
|
||||
env_spec_eval=None,
|
||||
eval_episodes=10,
|
||||
device=None,
|
||||
):
|
||||
self.env_spec = env_spec
|
||||
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
|
||||
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
|
||||
self.optimizers = optimizers
|
||||
self.learning_rate = learning_rate
|
||||
self.n_steps = n_steps
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.gamma = gamma
|
||||
self.total_timesteps = total_timesteps
|
||||
self.eval_interval = eval_interval
|
||||
self.eval_deterministic = eval_deterministic
|
||||
self.entropy_coef = entropy_coef
|
||||
self.critic_coef = critic_coef
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.eval_episodes = eval_episodes
|
||||
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
super().__init__(
|
||||
env_spec=env_spec,
|
||||
loggers=loggers,
|
||||
optimizers=optimizers,
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
n_epochs=n_epochs,
|
||||
gamma=gamma,
|
||||
total_timesteps=total_timesteps,
|
||||
eval_interval=eval_interval,
|
||||
eval_deterministic=eval_deterministic,
|
||||
entropy_coef=entropy_coef,
|
||||
critic_coef=critic_coef,
|
||||
normalize_advantage=normalize_advantage,
|
||||
device=device,
|
||||
env_spec_eval=env_spec_eval,
|
||||
eval_episodes=eval_episodes,
|
||||
)
|
||||
|
||||
# Create collector
|
||||
self.collector = SyncDataCollector(
|
||||
@ -69,21 +77,6 @@ class OnPolicy(ABC):
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
def make_env(self, eval=False):
|
||||
"""Creates an environment and wraps it if necessary."""
|
||||
env_spec = self.env_spec_eval if eval else self.env_spec
|
||||
if isinstance(env_spec, str):
|
||||
env = gym.make(env_spec)
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif callable(env_spec):
|
||||
env = env_spec()
|
||||
if isinstance(env, gym.Env):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
elif isinstance(env, gym.Env):
|
||||
env = GymWrapper(env).to(self.device)
|
||||
else:
|
||||
raise ValueError("env_spec must be a string or a callable that returns an environment.")
|
||||
return env
|
||||
|
||||
def train_step(self, batch):
|
||||
for optimizer in self.optimizers.values():
|
||||
@ -140,7 +133,3 @@ class OnPolicy(ABC):
|
||||
avg_return = torch.cat(test_rewards, 0).mean().item()
|
||||
for logger in self.loggers:
|
||||
logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
|
||||
|
||||
def dump_video(module):
|
||||
if isinstance(module, VideoRecorder):
|
||||
module.dump()
|
||||
|
Loading…
Reference in New Issue
Block a user