Refactor out some func into general Algo class
This commit is contained in:
parent
5f186af9fb
commit
8d5d44e992
75
fancy_rl/algos/algo.py
Normal file
75
fancy_rl/algos/algo.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import torch
|
||||||
|
import gymnasium as gym
|
||||||
|
from torchrl.envs.libs.gym import GymWrapper
|
||||||
|
from torchrl.record import VideoRecorder
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
from fancy_rl.loggers import TerminalLogger
|
||||||
|
|
||||||
|
class Algo(ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env_spec,
|
||||||
|
loggers,
|
||||||
|
optimizers,
|
||||||
|
learning_rate,
|
||||||
|
n_steps,
|
||||||
|
batch_size,
|
||||||
|
n_epochs,
|
||||||
|
gamma,
|
||||||
|
total_timesteps,
|
||||||
|
eval_interval,
|
||||||
|
eval_deterministic,
|
||||||
|
entropy_coef,
|
||||||
|
critic_coef,
|
||||||
|
normalize_advantage,
|
||||||
|
device=None,
|
||||||
|
eval_episodes=10,
|
||||||
|
env_spec_eval=None,
|
||||||
|
):
|
||||||
|
self.env_spec = env_spec
|
||||||
|
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
|
||||||
|
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
|
||||||
|
self.optimizers = optimizers
|
||||||
|
self.learning_rate = learning_rate
|
||||||
|
self.n_steps = n_steps
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.n_epochs = n_epochs
|
||||||
|
self.gamma = gamma
|
||||||
|
self.total_timesteps = total_timesteps
|
||||||
|
self.eval_interval = eval_interval
|
||||||
|
self.eval_deterministic = eval_deterministic
|
||||||
|
self.entropy_coef = entropy_coef
|
||||||
|
self.critic_coef = critic_coef
|
||||||
|
self.normalize_advantage = normalize_advantage
|
||||||
|
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.eval_episodes = eval_episodes
|
||||||
|
|
||||||
|
def make_env(self, eval=False):
|
||||||
|
"""Creates an environment and wraps it if necessary."""
|
||||||
|
env_spec = self.env_spec_eval if eval else self.env_spec
|
||||||
|
if isinstance(env_spec, str):
|
||||||
|
env = gym.make(env_spec)
|
||||||
|
env = GymWrapper(env).to(self.device)
|
||||||
|
elif callable(env_spec):
|
||||||
|
env = env_spec()
|
||||||
|
if isinstance(env, gym.Env):
|
||||||
|
env = GymWrapper(env).to(self.device)
|
||||||
|
elif isinstance(env, gym.Env):
|
||||||
|
env = GymWrapper(env).to(self.device)
|
||||||
|
else:
|
||||||
|
raise ValueError("env_spec must be a string or a callable that returns an environment.")
|
||||||
|
return env
|
||||||
|
|
||||||
|
def train_step(self, batch):
|
||||||
|
raise NotImplementedError("train_step method must be implemented in subclass.")
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
raise NotImplementedError("train method must be implemented in subclass.")
|
||||||
|
|
||||||
|
def evaluate(self, epoch):
|
||||||
|
raise NotImplementedError("evaluate method must be implemented in subclass.")
|
||||||
|
|
||||||
|
def dump_video(module):
|
||||||
|
if isinstance(module, VideoRecorder):
|
||||||
|
module.dump()
|
@ -5,50 +5,58 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
|||||||
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
|
||||||
from torchrl.envs.libs.gym import GymWrapper
|
from torchrl.envs.libs.gym import GymWrapper
|
||||||
from torchrl.envs import ExplorationType, set_exploration_type
|
from torchrl.envs import ExplorationType, set_exploration_type
|
||||||
from torchrl.record import VideoRecorder
|
|
||||||
from tensordict import LazyStackedTensorDict, TensorDict
|
|
||||||
from abc import ABC
|
|
||||||
|
|
||||||
from fancy_rl.loggers import TerminalLogger
|
from fancy_rl.loggers import TerminalLogger
|
||||||
|
from fancy_rl.algos.algo import Algo
|
||||||
|
|
||||||
class OnPolicy(ABC):
|
class OnPolicy(Algo):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env_spec,
|
env_spec,
|
||||||
loggers,
|
|
||||||
optimizers,
|
optimizers,
|
||||||
learning_rate,
|
loggers=None,
|
||||||
n_steps,
|
actor_hidden_sizes=[64, 64],
|
||||||
batch_size,
|
critic_hidden_sizes=[64, 64],
|
||||||
n_epochs,
|
actor_activation_fn="Tanh",
|
||||||
gamma,
|
critic_activation_fn="Tanh",
|
||||||
total_timesteps,
|
learning_rate=3e-4,
|
||||||
eval_interval,
|
n_steps=2048,
|
||||||
eval_deterministic,
|
batch_size=64,
|
||||||
entropy_coef,
|
n_epochs=10,
|
||||||
critic_coef,
|
gamma=0.99,
|
||||||
normalize_advantage,
|
gae_lambda=0.95,
|
||||||
device=None,
|
total_timesteps=1e6,
|
||||||
eval_episodes=10,
|
eval_interval=2048,
|
||||||
|
eval_deterministic=True,
|
||||||
|
entropy_coef=0.01,
|
||||||
|
critic_coef=0.5,
|
||||||
|
normalize_advantage=True,
|
||||||
|
clip_range=0.2,
|
||||||
env_spec_eval=None,
|
env_spec_eval=None,
|
||||||
|
eval_episodes=10,
|
||||||
|
device=None,
|
||||||
):
|
):
|
||||||
self.env_spec = env_spec
|
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
|
|
||||||
self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
|
super().__init__(
|
||||||
self.optimizers = optimizers
|
env_spec=env_spec,
|
||||||
self.learning_rate = learning_rate
|
loggers=loggers,
|
||||||
self.n_steps = n_steps
|
optimizers=optimizers,
|
||||||
self.batch_size = batch_size
|
learning_rate=learning_rate,
|
||||||
self.n_epochs = n_epochs
|
n_steps=n_steps,
|
||||||
self.gamma = gamma
|
batch_size=batch_size,
|
||||||
self.total_timesteps = total_timesteps
|
n_epochs=n_epochs,
|
||||||
self.eval_interval = eval_interval
|
gamma=gamma,
|
||||||
self.eval_deterministic = eval_deterministic
|
total_timesteps=total_timesteps,
|
||||||
self.entropy_coef = entropy_coef
|
eval_interval=eval_interval,
|
||||||
self.critic_coef = critic_coef
|
eval_deterministic=eval_deterministic,
|
||||||
self.normalize_advantage = normalize_advantage
|
entropy_coef=entropy_coef,
|
||||||
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
|
critic_coef=critic_coef,
|
||||||
self.eval_episodes = eval_episodes
|
normalize_advantage=normalize_advantage,
|
||||||
|
device=device,
|
||||||
|
env_spec_eval=env_spec_eval,
|
||||||
|
eval_episodes=eval_episodes,
|
||||||
|
)
|
||||||
|
|
||||||
# Create collector
|
# Create collector
|
||||||
self.collector = SyncDataCollector(
|
self.collector = SyncDataCollector(
|
||||||
@ -69,21 +77,6 @@ class OnPolicy(ABC):
|
|||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_env(self, eval=False):
|
|
||||||
"""Creates an environment and wraps it if necessary."""
|
|
||||||
env_spec = self.env_spec_eval if eval else self.env_spec
|
|
||||||
if isinstance(env_spec, str):
|
|
||||||
env = gym.make(env_spec)
|
|
||||||
env = GymWrapper(env).to(self.device)
|
|
||||||
elif callable(env_spec):
|
|
||||||
env = env_spec()
|
|
||||||
if isinstance(env, gym.Env):
|
|
||||||
env = GymWrapper(env).to(self.device)
|
|
||||||
elif isinstance(env, gym.Env):
|
|
||||||
env = GymWrapper(env).to(self.device)
|
|
||||||
else:
|
|
||||||
raise ValueError("env_spec must be a string or a callable that returns an environment.")
|
|
||||||
return env
|
|
||||||
|
|
||||||
def train_step(self, batch):
|
def train_step(self, batch):
|
||||||
for optimizer in self.optimizers.values():
|
for optimizer in self.optimizers.values():
|
||||||
@ -140,7 +133,3 @@ class OnPolicy(ABC):
|
|||||||
avg_return = torch.cat(test_rewards, 0).mean().item()
|
avg_return = torch.cat(test_rewards, 0).mean().item()
|
||||||
for logger in self.loggers:
|
for logger in self.loggers:
|
||||||
logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
|
logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
|
||||||
|
|
||||||
def dump_video(module):
|
|
||||||
if isinstance(module, VideoRecorder):
|
|
||||||
module.dump()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user