Refactor out some func into general Algo class

This commit is contained in:
Dominik Moritz Roth 2024-06-02 16:36:59 +02:00
parent 5f186af9fb
commit 8d5d44e992
2 changed files with 119 additions and 55 deletions

75
fancy_rl/algos/algo.py Normal file
View File

@ -0,0 +1,75 @@
import torch
import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
from fancy_rl.loggers import TerminalLogger
class Algo(ABC):
def __init__(
self,
env_spec,
loggers,
optimizers,
learning_rate,
n_steps,
batch_size,
n_epochs,
gamma,
total_timesteps,
eval_interval,
eval_deterministic,
entropy_coef,
critic_coef,
normalize_advantage,
device=None,
eval_episodes=10,
env_spec_eval=None,
):
self.env_spec = env_spec
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
self.optimizers = optimizers
self.learning_rate = learning_rate
self.n_steps = n_steps
self.batch_size = batch_size
self.n_epochs = n_epochs
self.gamma = gamma
self.total_timesteps = total_timesteps
self.eval_interval = eval_interval
self.eval_deterministic = eval_deterministic
self.entropy_coef = entropy_coef
self.critic_coef = critic_coef
self.normalize_advantage = normalize_advantage
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self.eval_episodes = eval_episodes
def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec
if isinstance(env_spec, str):
env = gym.make(env_spec)
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
return env
def train_step(self, batch):
raise NotImplementedError("train_step method must be implemented in subclass.")
def train(self):
raise NotImplementedError("train method must be implemented in subclass.")
def evaluate(self, epoch):
raise NotImplementedError("evaluate method must be implemented in subclass.")
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()

View File

@ -5,50 +5,58 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymWrapper from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs import ExplorationType, set_exploration_type from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.record import VideoRecorder
from tensordict import LazyStackedTensorDict, TensorDict
from abc import ABC
from fancy_rl.loggers import TerminalLogger from fancy_rl.loggers import TerminalLogger
from fancy_rl.algos.algo import Algo
class OnPolicy(ABC): class OnPolicy(Algo):
def __init__( def __init__(
self, self,
env_spec, env_spec,
loggers,
optimizers, optimizers,
learning_rate, loggers=None,
n_steps, actor_hidden_sizes=[64, 64],
batch_size, critic_hidden_sizes=[64, 64],
n_epochs, actor_activation_fn="Tanh",
gamma, critic_activation_fn="Tanh",
total_timesteps, learning_rate=3e-4,
eval_interval, n_steps=2048,
eval_deterministic, batch_size=64,
entropy_coef, n_epochs=10,
critic_coef, gamma=0.99,
normalize_advantage, gae_lambda=0.95,
device=None, total_timesteps=1e6,
eval_episodes=10, eval_interval=2048,
eval_deterministic=True,
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
clip_range=0.2,
env_spec_eval=None, env_spec_eval=None,
eval_episodes=10,
device=None,
): ):
self.env_spec = env_spec device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)] super().__init__(
self.optimizers = optimizers env_spec=env_spec,
self.learning_rate = learning_rate loggers=loggers,
self.n_steps = n_steps optimizers=optimizers,
self.batch_size = batch_size learning_rate=learning_rate,
self.n_epochs = n_epochs n_steps=n_steps,
self.gamma = gamma batch_size=batch_size,
self.total_timesteps = total_timesteps n_epochs=n_epochs,
self.eval_interval = eval_interval gamma=gamma,
self.eval_deterministic = eval_deterministic total_timesteps=total_timesteps,
self.entropy_coef = entropy_coef eval_interval=eval_interval,
self.critic_coef = critic_coef eval_deterministic=eval_deterministic,
self.normalize_advantage = normalize_advantage entropy_coef=entropy_coef,
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") critic_coef=critic_coef,
self.eval_episodes = eval_episodes normalize_advantage=normalize_advantage,
device=device,
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
)
# Create collector # Create collector
self.collector = SyncDataCollector( self.collector = SyncDataCollector(
@ -69,21 +77,6 @@ class OnPolicy(ABC):
batch_size=self.batch_size, batch_size=self.batch_size,
) )
def make_env(self, eval=False):
"""Creates an environment and wraps it if necessary."""
env_spec = self.env_spec_eval if eval else self.env_spec
if isinstance(env_spec, str):
env = gym.make(env_spec)
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
return env
def train_step(self, batch): def train_step(self, batch):
for optimizer in self.optimizers.values(): for optimizer in self.optimizers.values():
@ -139,8 +132,4 @@ class OnPolicy(ABC):
avg_return = torch.cat(test_rewards, 0).mean().item() avg_return = torch.cat(test_rewards, 0).mean().item()
for logger in self.loggers: for logger in self.loggers:
logger.log_scalar({"eval_avg_return": avg_return}, step=epoch) logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()