diff --git a/fancy_rl/algos/algo.py b/fancy_rl/algos/algo.py
new file mode 100644
index 0000000..b3b150f
--- /dev/null
+++ b/fancy_rl/algos/algo.py
@@ -0,0 +1,75 @@
+import torch
+import gymnasium as gym
+from torchrl.envs.libs.gym import GymWrapper
+from torchrl.record import VideoRecorder
+from abc import ABC
+
+from fancy_rl.loggers import TerminalLogger
+
+class Algo(ABC):
+    def __init__(
+        self,
+        env_spec,
+        loggers,
+        optimizers,
+        learning_rate,
+        n_steps,
+        batch_size,
+        n_epochs,
+        gamma,
+        total_timesteps,
+        eval_interval,
+        eval_deterministic,
+        entropy_coef,
+        critic_coef,
+        normalize_advantage,
+        device=None,
+        eval_episodes=10,
+        env_spec_eval=None,
+    ):
+        self.env_spec = env_spec
+        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
+        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
+        self.optimizers = optimizers
+        self.learning_rate = learning_rate
+        self.n_steps = n_steps
+        self.batch_size = batch_size
+        self.n_epochs = n_epochs
+        self.gamma = gamma
+        self.total_timesteps = total_timesteps
+        self.eval_interval = eval_interval
+        self.eval_deterministic = eval_deterministic
+        self.entropy_coef = entropy_coef
+        self.critic_coef = critic_coef
+        self.normalize_advantage = normalize_advantage
+        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+        self.eval_episodes = eval_episodes
+
+    def make_env(self, eval=False):
+        """Creates an environment and wraps it if necessary."""
+        env_spec = self.env_spec_eval if eval else self.env_spec
+        if isinstance(env_spec, str):
+            env = gym.make(env_spec)
+            env = GymWrapper(env).to(self.device)
+        elif callable(env_spec):
+            env = env_spec()
+            if isinstance(env, gym.Env):
+                env = GymWrapper(env).to(self.device)
+        elif isinstance(env, gym.Env):
+            env = GymWrapper(env).to(self.device)
+        else:
+            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+        return env
+
+    def train_step(self, batch):
+        raise NotImplementedError("train_step method must be implemented in subclass.")
+
+    def train(self):
+        raise NotImplementedError("train method must be implemented in subclass.")
+
+    def evaluate(self, epoch):
+        raise NotImplementedError("evaluate method must be implemented in subclass.")
+
+def dump_video(module):
+    if isinstance(module, VideoRecorder):
+        module.dump()
\ No newline at end of file
diff --git a/fancy_rl/algos/on_policy.py b/fancy_rl/algos/on_policy.py
index cf26219..9b61c54 100644
--- a/fancy_rl/algos/on_policy.py
+++ b/fancy_rl/algos/on_policy.py
@@ -5,50 +5,58 @@ from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.envs import ExplorationType, set_exploration_type
-from torchrl.record import VideoRecorder
-from tensordict import LazyStackedTensorDict, TensorDict
-from abc import ABC
 
 from fancy_rl.loggers import TerminalLogger
+from fancy_rl.algos.algo import Algo
 
-class OnPolicy(ABC):
+class OnPolicy(Algo):
     def __init__(
         self,
         env_spec,
-        loggers,
         optimizers,
-        learning_rate,
-        n_steps,
-        batch_size,
-        n_epochs,
-        gamma,
-        total_timesteps,
-        eval_interval,
-        eval_deterministic,
-        entropy_coef,
-        critic_coef,
-        normalize_advantage,
-        device=None,
-        eval_episodes=10,
+        loggers=None,
+        actor_hidden_sizes=[64, 64],
+        critic_hidden_sizes=[64, 64],
+        actor_activation_fn="Tanh",
+        critic_activation_fn="Tanh",
+        learning_rate=3e-4,
+        n_steps=2048,
+        batch_size=64,
+        n_epochs=10,
+        gamma=0.99,
+        gae_lambda=0.95,
+        total_timesteps=1e6,
+        eval_interval=2048,
+        eval_deterministic=True,
+        entropy_coef=0.01,
+        critic_coef=0.5,
+        normalize_advantage=True,
+        clip_range=0.2,
         env_spec_eval=None,
+        eval_episodes=10,
+        device=None,
     ):
-        self.env_spec = env_spec
-        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
-        self.loggers = loggers if loggers != None else [TerminalLogger(None, None)]
-        self.optimizers = optimizers
-        self.learning_rate = learning_rate
-        self.n_steps = n_steps
-        self.batch_size = batch_size
-        self.n_epochs = n_epochs
-        self.gamma = gamma
-        self.total_timesteps = total_timesteps
-        self.eval_interval = eval_interval
-        self.eval_deterministic = eval_deterministic
-        self.entropy_coef = entropy_coef
-        self.critic_coef = critic_coef
-        self.normalize_advantage = normalize_advantage
-        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
-        self.eval_episodes = eval_episodes
+        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        super().__init__(
+            env_spec=env_spec,
+            loggers=loggers,
+            optimizers=optimizers,
+            learning_rate=learning_rate,
+            n_steps=n_steps,
+            batch_size=batch_size,
+            n_epochs=n_epochs,
+            gamma=gamma,
+            total_timesteps=total_timesteps,
+            eval_interval=eval_interval,
+            eval_deterministic=eval_deterministic,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            normalize_advantage=normalize_advantage,
+            device=device,
+            env_spec_eval=env_spec_eval,
+            eval_episodes=eval_episodes,
+        )
 
         # Create collector
         self.collector = SyncDataCollector(
@@ -69,21 +77,6 @@ class OnPolicy(ABC):
             batch_size=self.batch_size,
         )
 
-    def make_env(self, eval=False):
-        """Creates an environment and wraps it if necessary."""
-        env_spec = self.env_spec_eval if eval else self.env_spec
-        if isinstance(env_spec, str):
-            env = gym.make(env_spec)
-            env = GymWrapper(env).to(self.device)
-        elif callable(env_spec):
-            env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
-        elif isinstance(env, gym.Env):
-            env = GymWrapper(env).to(self.device)
-        else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
-        return env
 
     def train_step(self, batch):
         for optimizer in self.optimizers.values():
@@ -139,8 +132,4 @@ class OnPolicy(ABC):
 
         avg_return = torch.cat(test_rewards, 0).mean().item()
         for logger in self.loggers:
-            logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
-
-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
+            logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
\ No newline at end of file
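
For context, a minimal sketch of how the extracted Algo base class is meant to be consumed: a subclass hands the shared hyperparameters to Algo.__init__ and overrides the train_step/train/evaluate hooks, while make_env accepts either a Gymnasium id string or a callable that returns an environment. The DummyAlgo name, the empty optimizers dict, and the hyperparameter values below are illustrative assumptions, not part of this diff.

# Illustrative usage sketch only; fancy_rl does not ship DummyAlgo, and the
# hyperparameter values are arbitrary placeholders.
from fancy_rl.algos.algo import Algo


class DummyAlgo(Algo):
    """Minimal subclass; a real algorithm overrides the three hooks below."""

    def train_step(self, batch):
        # A concrete algorithm would compute its losses and step self.optimizers here.
        pass

    def train(self):
        pass

    def evaluate(self, epoch):
        pass


algo = DummyAlgo(
    env_spec="CartPole-v1",   # a Gym id string; a callable returning a gym.Env also works
    loggers=None,             # falls back to [TerminalLogger(None, None)]
    optimizers={},
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    n_epochs=10,
    gamma=0.99,
    total_timesteps=1_000_000,
    eval_interval=2048,
    eval_deterministic=True,
    entropy_coef=0.01,
    critic_coef=0.5,
    normalize_advantage=True,
)

# make_env resolves the string via gym.make and wraps the result in a TorchRL
# GymWrapper on self.device; with eval=True it uses env_spec_eval instead.
env = algo.make_env()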