diff --git a/fancy_rl/__init__.py b/fancy_rl/__init__.py
index efd1a8b..c0e7081 100644
--- a/fancy_rl/__init__.py
+++ b/fancy_rl/__init__.py
@@ -1,6 +1,9 @@
-from fancy_rl.ppo import PPO
-from fancy_rl.policy import MLPPolicy
-from fancy_rl.loggers import TerminalLogger, WandbLogger
-from fancy_rl.utils import make_env
+import gymnasium
+try:
+    import fancy_gym
+except ImportError:
+    pass
 
-__all__ = ["PPO", "MLPPolicy", "TerminalLogger", "WandbLogger", "make_env"]
+from fancy_rl.ppo import PPO
+
+__all__ = ["PPO"]
\ No newline at end of file
diff --git a/fancy_rl/on_policy.py b/fancy_rl/on_policy.py
index 6093e29..c090b48 100644
--- a/fancy_rl/on_policy.py
+++ b/fancy_rl/on_policy.py
@@ -1,18 +1,13 @@
 import torch
-from abc import ABC, abstractmethod
-from torchrl.record.loggers import Logger
-from torch.optim import Adam
+import gymnasium as gym
 from torchrl.collectors import SyncDataCollector
 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
-from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.envs.libs.gym import GymWrapper
+from torchrl.envs import ExplorationType, set_exploration_type
 from torchrl.record import VideoRecorder
-import gymnasium as gym
-try:
-    import fancy_gym
-except ImportError:
-    pass
+from abc import ABC, abstractmethod
+
 
 class OnPolicy(ABC):
     def __init__(
@@ -20,6 +15,7 @@ class OnPolicy(ABC):
         policy,
         env_spec,
         loggers,
+        optimizers,
         learning_rate,
         n_steps,
         batch_size,
@@ -41,6 +37,7 @@ class OnPolicy(ABC):
         self.env_spec = env_spec
         self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
         self.loggers = loggers
+        self.optimizers = optimizers
         self.learning_rate = learning_rate
         self.n_steps = n_steps
         self.batch_size = batch_size
@@ -90,6 +87,15 @@ class OnPolicy(ABC):
             raise ValueError("env_spec must be a string or a callable that returns an environment.")
         return env
 
+    def train_step(self, batch):
+        for optimizer in self.optimizers.values():
+            optimizer.zero_grad()
+        loss = self.loss_module(batch)
+        loss.backward()
+        for optimizer in self.optimizers.values():
+            optimizer.step()
+        return loss
+
     def train(self):
         collected_frames = 0
 
@@ -136,10 +142,6 @@ class OnPolicy(ABC):
             for logger in self.loggers:
                 logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
 
-    @abstractmethod
-    def train_step(self, batch):
-        pass
-
 def dump_video(module):
     if isinstance(module, VideoRecorder):
         module.dump()
diff --git a/fancy_rl/policy.py b/fancy_rl/policy.py
index 5875986..9146eb0 100644
--- a/fancy_rl/policy.py
+++ b/fancy_rl/policy.py
@@ -1,71 +1,59 @@
-import torch
-from torch import nn
-from torch.distributions import Categorical, Normal
-import gymnasium as gym
+import torch.nn as nn
+from tensordict.nn import TensorDictModule
+from torchrl.modules import MLP
+from tensordict.nn.distributions import NormalParamExtractor
 
-class Actor(nn.Module):
-    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        self.continuous = isinstance(action_space, gym.spaces.Box)
-        input_dim = observation_space.shape[-1]
-        if self.continuous:
-            output_dim = action_space.shape[-1]
+class SharedModule(TensorDictModule):
+    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        if hidden_sizes:
+            shared_module = MLP(
+                in_features=obs_space.shape[-1],
+                out_features=hidden_sizes[-1],
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            )
+            out_features = hidden_sizes[-1]
         else:
-            output_dim = action_space.n
+            shared_module = nn.Identity()
+            out_features = obs_space.shape[-1]
 
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-
-        if self.continuous:
-            self.mu_layer = nn.Linear(last_dim, output_dim)
-            self.log_std_layer = nn.Linear(last_dim, output_dim)
-        else:
-            layers.append(nn.Linear(last_dim, output_dim))
-        self.model = nn.Sequential(*layers)
+        super().__init__(
+            module=shared_module,
+            in_keys=["observation"],
+            out_keys=["shared"],
+        )
+        self.out_features = out_features
 
-    def forward(self, x):
-        if self.continuous:
-            mu = self.mu_layer(x)
-            log_std = self.log_std_layer(x)
-            return mu, log_std.exp()
-        else:
-            return self.model(x)
+class Actor(TensorDictModule):
+    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
+        actor_module = nn.Sequential(
+            MLP(
+                in_features=shared_module.out_features,
+                out_features=act_space.shape[-1] * 2,
+                num_cells=hidden_sizes,
+                activation_class=getattr(nn, activation_fn),
+                device=device
+            ),
+            NormalParamExtractor(),
+        ).to(device)
+        super().__init__(
+            module=actor_module,
+            in_keys=["shared"],
+            out_keys=["loc", "scale"],
+        )
 
-    def act(self, observation, deterministic=False):
-        with torch.no_grad():
-            if self.continuous:
-                mu, std = self.forward(observation)
-                if deterministic:
-                    action = mu
-                else:
-                    action_dist = Normal(mu, std)
-                    action = action_dist.sample()
-            else:
-                logits = self.forward(observation)
-                if deterministic:
-                    action = logits.argmax(dim=-1)
-                else:
-                    action_dist = Categorical(logits=logits)
-                    action = action_dist.sample()
-        return action
-
-class Critic(nn.Module):
-    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
-        super().__init__()
-        input_dim = observation_space.shape[-1]
-
-        layers = []
-        last_dim = input_dim
-        for size in hidden_sizes:
-            layers.append(nn.Linear(last_dim, size))
-            layers.append(activation_fn())
-            last_dim = size
-        layers.append(nn.Linear(last_dim, 1))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.model(x).squeeze(-1)
+class Critic(TensorDictModule):
+    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
+        critic_module = MLP(
+            in_features=shared_module.out_features,
+            out_features=1,
+            num_cells=hidden_sizes,
+            activation_class=getattr(nn, activation_fn),
+            device=device
+        ).to(device)
+        super().__init__(
+            module=critic_module,
+            in_keys=["shared"],
+            out_keys=["state_value"],
+        )
diff --git a/fancy_rl/ppo.py b/fancy_rl/ppo.py
index 02f36da..0459a2d 100644
--- a/fancy_rl/ppo.py
+++ b/fancy_rl/ppo.py
@@ -1,11 +1,9 @@
 import torch
-import torch.nn as nn
+from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
-from torchrl.record.loggers import get_logger
-from on_policy import OnPolicy
-from policy import Actor, Critic
-import gymnasium as gym
+from fancy_rl.on_policy import OnPolicy
+from fancy_rl.policy import Actor, Critic, SharedModule
 
 class PPO(OnPolicy):
     def __init__(
@@ -14,8 +12,9 @@ class PPO(OnPolicy):
         loggers=None,
         actor_hidden_sizes=[64, 64],
         critic_hidden_sizes=[64, 64],
-        actor_activation_fn="ReLU",
-        critic_activation_fn="ReLU",
+        actor_activation_fn="Tanh",
+        critic_activation_fn="Tanh",
+        shared_stem_sizes=[64],
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
@@ -33,21 +32,45 @@
         env_spec_eval=None,
         eval_episodes=10,
     ):
+        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         # Initialize environment to get observation and action space sizes
-        env = self.make_env(env_spec)
+        self.env_spec = env_spec
+        env = self.make_env()
         obs_space = env.observation_space
         act_space = env.action_space
 
-        actor_activation_fn = getattr(nn, actor_activation_fn)
-        critic_activation_fn = getattr(nn, critic_activation_fn)
+        # Define the shared, actor, and critic modules
+        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
+        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
 
-        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
-        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
+        # Combine into an ActorValueOperator
+        self.ac_module = ActorValueOperator(
+            self.shared_module,
+            self.actor,
+            self.critic
+        )
+
+        # Define the policy as a ProbabilisticActor
+        self.policy = ProbabilisticActor(
+            module=self.ac_module.get_policy_operator(),
+            in_keys=["loc", "scale"],
+            out_keys=["action"],
+            distribution_class=torch.distributions.Normal,
+            return_log_prob=True
+        )
+
+        optimizers = {
+            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
+        }
 
         super().__init__(
-            policy=self.actor,
+            policy=self.policy,
             env_spec=env_spec,
             loggers=loggers,
+            optimizers=optimizers,
             learning_rate=learning_rate,
             n_steps=n_steps,
             batch_size=batch_size,
@@ -82,15 +105,3 @@ class PPO(OnPolicy):
             critic_coef=self.critic_coef,
             normalize_advantage=self.normalize_advantage,
         )
-
-        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
-        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
-
-    def train_step(self, batch):
-        self.actor_optimizer.zero_grad()
-        self.critic_optimizer.zero_grad()
-        loss = self.loss_module(batch)
-        loss.backward()
-        self.actor_optimizer.step()
-        self.critic_optimizer.step()
-        return loss
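Usage note: a minimal, hypothetical sketch of driving the refactored PPO class after this change. It assumes the package is importable as fancy_rl; "HalfCheetah-v4" is a stand-in Gymnasium environment id, not something this diff prescribes. Only PPO, its constructor keywords, and train() are taken from the code above.

from fancy_rl import PPO  # the only symbol re-exported by the new __init__.py

ppo = PPO(
    env_spec="HalfCheetah-v4",    # assumed env id; any continuous-action Gymnasium spec should fit,
                                  # since the new Actor emits loc/scale for a Normal distribution
    shared_stem_sizes=[64],       # width of the shared stem introduced in this diff
    actor_hidden_sizes=[64, 64],
    critic_hidden_sizes=[64, 64],
    learning_rate=3e-4,
)
ppo.train()  # OnPolicy.train() drives the consolidated train_step() and the optimizers dict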