Dominik Moritz Roth 2024-05-31 18:25:03 +02:00
parent 0bf748869a
commit 015f1e256a
4 changed files with 113 additions and 109 deletions

fancy_rl/__init__.py · View File

@@ -1,6 +1,9 @@
from fancy_rl.ppo import PPO
from fancy_rl.policy import MLPPolicy
from fancy_rl.loggers import TerminalLogger, WandbLogger
from fancy_rl.utils import make_env
import gymnasium
try:
    import fancy_gym
except ImportError:
    pass
__all__ = ["PPO", "MLPPolicy", "TerminalLogger", "WandbLogger", "make_env"]
from fancy_rl.ppo import PPO
__all__ = ["PPO"]
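For context on the import block above: the try/except keeps fancy_gym an optional dependency, so importing fancy_rl succeeds without it and only env specs that need fancy_gym's registered environments require the extra install. A minimal sketch of the same guard pattern (the HAS_FANCY_GYM flag is illustrative, not part of fancy_rl):

# Optional-dependency guard, as in the __init__ diff above.
try:
    import fancy_gym  # importing fancy_gym registers its environments with gymnasium
    HAS_FANCY_GYM = True
except ImportError:
    HAS_FANCY_GYM = False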

fancy_rl/on_policy.py · View File

@@ -1,18 +1,13 @@
import torch
from abc import ABC, abstractmethod
from torchrl.record.loggers import Logger
from torch.optim import Adam
import gymnasium as gym
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs import ExplorationType, set_exploration_type
from torchrl.record import VideoRecorder
import gymnasium as gym
try:
import fancy_gym
except ImportError:
pass
from abc import ABC, abstractmethod
class OnPolicy(ABC):
    def __init__(
@@ -20,6 +15,7 @@ class OnPolicy(ABC):
        policy,
        env_spec,
        loggers,
        optimizers,
        learning_rate,
        n_steps,
        batch_size,
@@ -41,6 +37,7 @@ class OnPolicy(ABC):
        self.env_spec = env_spec
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
        self.loggers = loggers
        self.optimizers = optimizers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
@@ -90,6 +87,15 @@ class OnPolicy(ABC):
            raise ValueError("env_spec must be a string or a callable that returns an environment.")
        return env
    def train_step(self, batch):
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()
        loss = self.loss_module(batch)
        loss.backward()
        for optimizer in self.optimizers.values():
            optimizer.step()
        return loss
    def train(self):
        collected_frames = 0
@@ -136,10 +142,6 @@ class OnPolicy(ABC):
        for logger in self.loggers:
            logger.log_scalar({"eval_avg_return": avg_return}, step=epoch)
    @abstractmethod
    def train_step(self, batch):
        pass
def dump_video(module):
    if isinstance(module, VideoRecorder):
        module.dump()
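The train_step moved into OnPolicy centralizes optimization: subclasses hand the base class a dict of optimizers, and the base class zeroes all of them, backpropagates one loss, and steps all of them. A self-contained sketch of that pattern, using placeholder networks rather than fancy_rl's own modules:

import torch
import torch.nn as nn

# Placeholder networks standing in for the actor and critic.
actor_net = nn.Linear(4, 2)
critic_net = nn.Linear(4, 1)

optimizers = {
    "actor": torch.optim.Adam(actor_net.parameters(), lr=3e-4),
    "critic": torch.optim.Adam(critic_net.parameters(), lr=3e-4),
}

def train_step(loss):
    # Same shape as OnPolicy.train_step above: zero every optimizer,
    # backpropagate the combined loss once, then step every optimizer.
    for optimizer in optimizers.values():
        optimizer.zero_grad()
    loss.backward()
    for optimizer in optimizers.values():
        optimizer.step()
    return loss

obs = torch.randn(8, 4)
train_step(actor_net(obs).pow(2).mean() + critic_net(obs).pow(2).mean())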

fancy_rl/policy.py · View File

@@ -1,71 +1,59 @@
import torch
from torch import nn
from torch.distributions import Categorical, Normal
import gymnasium as gym
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from tensordict.nn.distributions import NormalParamExtractor
class Actor(nn.Module):
    def __init__(self, observation_space, action_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
        super().__init__()
        self.continuous = isinstance(action_space, gym.spaces.Box)
        input_dim = observation_space.shape[-1]
        if self.continuous:
            output_dim = action_space.shape[-1]
class SharedModule(TensorDictModule):
    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
        if hidden_sizes:
            shared_module = MLP(
                in_features=obs_space.shape[-1],
                out_features=hidden_sizes[-1],
                num_cells=hidden_sizes,
                activation_class=getattr(nn, activation_fn),
                device=device
            )
            out_features = hidden_sizes[-1]
        else:
            output_dim = action_space.n
            shared_module = nn.Identity()
            out_features = obs_space.shape[-1]
        layers = []
        last_dim = input_dim
        for size in hidden_sizes:
            layers.append(nn.Linear(last_dim, size))
            layers.append(activation_fn())
            last_dim = size
        super().__init__(
            module=shared_module,
            in_keys=["observation"],
            out_keys=["shared"],
        )
        self.out_features = out_features
        if self.continuous:
            self.mu_layer = nn.Linear(last_dim, output_dim)
            self.log_std_layer = nn.Linear(last_dim, output_dim)
        else:
            layers.append(nn.Linear(last_dim, output_dim))
        self.model = nn.Sequential(*layers)
class Actor(TensorDictModule):
    def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device):
        actor_module = nn.Sequential(
            MLP(
                in_features=shared_module.out_features,
                out_features=act_space.shape[-1] * 2,
                num_cells=hidden_sizes,
                activation_class=getattr(nn, activation_fn),
                device=device
            ),
            NormalParamExtractor(),
        ).to(device)
        super().__init__(
            module=actor_module,
            in_keys=["shared"],
            out_keys=["loc", "scale"],
        )
    def forward(self, x):
        if self.continuous:
            mu = self.mu_layer(x)
            log_std = self.log_std_layer(x)
            return mu, log_std.exp()
        else:
            return self.model(x)
    def act(self, observation, deterministic=False):
        with torch.no_grad():
            if self.continuous:
                mu, std = self.forward(observation)
                if deterministic:
                    action = mu
                else:
                    action_dist = Normal(mu, std)
                    action = action_dist.sample()
            else:
                logits = self.forward(observation)
                if deterministic:
                    action = logits.argmax(dim=-1)
                else:
                    action_dist = Categorical(logits=logits)
                    action = action_dist.sample()
        return action
class Critic(nn.Module):
    def __init__(self, observation_space, hidden_sizes=[64, 64], activation_fn=nn.ReLU):
        super().__init__()
        input_dim = observation_space.shape[-1]
        layers = []
        last_dim = input_dim
        for size in hidden_sizes:
            layers.append(nn.Linear(last_dim, size))
            layers.append(activation_fn())
            last_dim = size
        layers.append(nn.Linear(last_dim, 1))
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x).squeeze(-1)
class Critic(TensorDictModule):
    def __init__(self, shared_module, hidden_sizes, activation_fn, device):
        critic_module = MLP(
            in_features=shared_module.out_features,
            out_features=1,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
            device=device
        ).to(device)
        super().__init__(
            module=critic_module,
            in_keys=["shared"],
            out_keys=["state_value"],
        )
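The rewritten policy classes are thin TensorDictModule wrappers: SharedModule maps the "observation" key to a "shared" feature key, the new Actor reads "shared" and writes "loc"/"scale" via NormalParamExtractor, and the new Critic reads "shared" and writes "state_value". A rough sketch of how such modules chain over a TensorDict, written against the public tensordict/torchrl APIs rather than the fancy_rl classes themselves:

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MLP

obs_dim, act_dim = 4, 2

# "observation" -> "shared": the shared trunk.
shared = TensorDictModule(
    MLP(in_features=obs_dim, out_features=64, num_cells=[64]),
    in_keys=["observation"], out_keys=["shared"],
)
# "shared" -> ("loc", "scale"): the actor head.
actor = TensorDictModule(
    nn.Sequential(
        MLP(in_features=64, out_features=2 * act_dim, num_cells=[64]),
        NormalParamExtractor(),
    ),
    in_keys=["shared"], out_keys=["loc", "scale"],
)
# "shared" -> "state_value": the critic head.
critic = TensorDictModule(
    MLP(in_features=64, out_features=1, num_cells=[64]),
    in_keys=["shared"], out_keys=["state_value"],
)

td = TensorDict({"observation": torch.randn(8, obs_dim)}, batch_size=[8])
TensorDictSequential(shared, actor, critic)(td)
print(td["loc"].shape, td["scale"].shape, td["state_value"].shape)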

fancy_rl/ppo.py · View File

@@ -1,11 +1,9 @@
import torch
import torch.nn as nn
from torchrl.modules import ActorValueOperator, ProbabilisticActor
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from torchrl.record.loggers import get_logger
from on_policy import OnPolicy
from policy import Actor, Critic
import gymnasium as gym
from fancy_rl.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic, SharedModule
class PPO(OnPolicy):
    def __init__(
@@ -14,8 +12,9 @@ class PPO(OnPolicy):
        loggers=None,
        actor_hidden_sizes=[64, 64],
        critic_hidden_sizes=[64, 64],
        actor_activation_fn="ReLU",
        critic_activation_fn="ReLU",
        actor_activation_fn="Tanh",
        critic_activation_fn="Tanh",
        shared_stem_sizes=[64],
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
@@ -33,21 +32,45 @@ class PPO(OnPolicy):
        env_spec_eval=None,
        eval_episodes=10,
    ):
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Initialize environment to get observation and action space sizes
        env = self.make_env(env_spec)
        self.env_spec = env_spec
        env = self.make_env()
        obs_space = env.observation_space
        act_space = env.action_space
        actor_activation_fn = getattr(nn, actor_activation_fn)
        critic_activation_fn = getattr(nn, critic_activation_fn)
        # Define the shared, actor, and critic modules
        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
        self.actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(obs_space, act_space, hidden_sizes=actor_hidden_sizes, activation_fn=actor_activation_fn)
        self.critic = Critic(obs_space, hidden_sizes=critic_hidden_sizes, activation_fn=critic_activation_fn)
        # Combine into an ActorValueOperator
        self.ac_module = ActorValueOperator(
            self.shared_module,
            self.actor,
            self.critic
        )
        # Define the policy as a ProbabilisticActor
        self.policy = ProbabilisticActor(
            module=self.ac_module.get_policy_operator(),
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=torch.distributions.Normal,
            return_log_prob=True
        )
        optimizers = {
            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
        }
        super().__init__(
            policy=self.actor,
            policy=self.policy,
            env_spec=env_spec,
            loggers=loggers,
            optimizers=optimizers,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
@@ -82,15 +105,3 @@ class PPO(OnPolicy):
            critic_coef=self.critic_coef,
            normalize_advantage=self.normalize_advantage,
        )
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.learning_rate)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.learning_rate)
    def train_step(self, batch):
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        loss = self.loss_module(batch)
        loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()
        return loss
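Taken together, the new constructor wires the shared trunk, actor head, and critic head into an ActorValueOperator, wraps its policy half in a ProbabilisticActor that samples from Normal(loc, scale), and passes per-network Adam optimizers down to OnPolicy, whose shared train_step replaces the one removed here. A hedged usage sketch; every keyword shown appears in the diff above, but the overall call pattern is an assumption, not documented API:

from fancy_rl.ppo import PPO

# A continuous-control task is assumed: the policy above samples from
# Normal(loc, scale), so a Box action space such as Pendulum-v1 fits.
ppo = PPO(
    env_spec="Pendulum-v1",
    loggers=[],
    actor_hidden_sizes=[64, 64],
    critic_hidden_sizes=[64, 64],
    shared_stem_sizes=[64],
    learning_rate=3e-4,
)
ppo.train()  # OnPolicy.train() runs the collection / optimization loop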