refactor algo impls

Dominik Moritz Roth 2024-08-28 11:33:20 +02:00
parent dd98af9f77
commit 4f58ce0ff2
4 changed files with 79 additions and 30 deletions

fancy_rl/algos/__init__.py

@@ -1 +1,3 @@
 from fancy_rl.algos.ppo import PPO
+from fancy_rl.algos.trpl import TRPL
+from fancy_rl.algos.vlearn import VLEARN
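
Note: with TRPL and VLEARN re-exported here, all three algorithms resolve from the same package path. A minimal sketch (assuming fancy_rl.algos.vlearn actually defines VLEARN, as the new import implies):

    from fancy_rl.algos import PPO, TRPL, VLEARN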

fancy_rl/algos/on_policy.py

@@ -15,23 +15,17 @@ class OnPolicy(Algo):
         env_spec,
         optimizers,
         loggers=None,
-        actor_hidden_sizes=[64, 64],
-        critic_hidden_sizes=[64, 64],
-        actor_activation_fn="Tanh",
-        critic_activation_fn="Tanh",
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
         n_epochs=10,
         gamma=0.99,
-        gae_lambda=0.95,
         total_timesteps=1e6,
         eval_interval=2048,
         eval_deterministic=True,
         entropy_coef=0.01,
         critic_coef=0.5,
         normalize_advantage=True,
-        clip_range=0.2,
         env_spec_eval=None,
         eval_episodes=10,
         device=None,
@@ -77,15 +71,25 @@ class OnPolicy(Algo):
             batch_size=self.batch_size,
         )

+    def pre_process_batch(self, batch):
+        return batch
+
+    def post_process_batch(self, batch):
+        pass
+
     def train_step(self, batch):
+        batch = self.pre_process_batch(batch)
         for optimizer in self.optimizers.values():
             optimizer.zero_grad()
         losses = self.loss_module(batch)
-        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
+        loss = sum(losses.values())  # Sum all losses
         loss.backward()
         for optimizer in self.optimizers.values():
             optimizer.step()
+        self.post_process_batch(batch)
         return loss

     def train(self):
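
Note: the new pre_process_batch / post_process_batch hooks let a subclass run per-batch work around the shared loss-and-optimizer step without overriding train_step itself. A rough sketch of an override; the subclass name is made up for illustration and the "advantage" key is assumed to match the GAE module's default output key:

    from fancy_rl.algos.on_policy import OnPolicy

    class ClampedAdvantageAlgo(OnPolicy):  # hypothetical subclass, illustration only
        def pre_process_batch(self, batch):
            # Runs before the loss module sees the batch: clamp extreme advantages
            # ("advantage" key assumed; adjust to whatever key the loss module reads).
            batch["advantage"] = batch["advantage"].clamp(-10.0, 10.0)
            return batch

        def post_process_batch(self, batch):
            # Runs after the optimizer step; a natural slot for syncing old/target networks.
            pass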

fancy_rl/algos/ppo.py

@@ -4,6 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
+from fancy_rl.projections import get_projection  # Updated import

 class PPO(OnPolicy):
     def __init__(

fancy_rl/algos/trpl.py

@@ -1,9 +1,16 @@
 import torch
-from torchrl.modules import ProbabilisticActor
-from torchrl.objectives.value.advantages import GAE
+from torch import nn
+from typing import Dict, Any, Optional
+from torchrl.modules import ProbabilisticActor, ValueOperator
+from torchrl.objectives import ClipPPOLoss
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
+from torchrl.objectives.value import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
+from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from copy import deepcopy

 class TRPL(OnPolicy):
     def __init__(
@@ -14,19 +21,21 @@ class TRPL(OnPolicy):
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
-        proj_layer_type=None,
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
         n_epochs=10,
         gamma=0.99,
         gae_lambda=0.95,
+        projection_class="identity_projection",
+        trust_region_coef=10.0,
+        trust_region_bound_mean=0.1,
+        trust_region_bound_cov=0.001,
         total_timesteps=1e6,
         eval_interval=2048,
         eval_deterministic=True,
         entropy_coef=0.01,
         critic_coef=0.5,
-        trust_region_coef=10.0,
         normalize_advantage=False,
         device=None,
         env_spec_eval=None,
@@ -35,9 +44,6 @@ class TRPL(OnPolicy):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device

-        self.trust_region_layer = None  # TODO: from proj_layer_type
-        self.trust_region_coef = trust_region_coef
-
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
@@ -46,14 +52,40 @@
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)

-        raw_actor = ProbabilisticActor(
-            module=actor_net,
+        # Handle projection_class
+        if isinstance(projection_class, str):
+            projection_class = get_projection(projection_class)
+        elif not issubclass(projection_class, BaseProjection):
+            raise ValueError("projection_class must be a string or a subclass of BaseProjection")
+
+        self.projection = projection_class(
             in_keys=["loc", "scale"],
-            out_keys=["action"],
+            out_keys=["loc", "scale"],
+            trust_region_bound_mean=trust_region_bound_mean,
+            trust_region_bound_cov=trust_region_bound_cov
+        )
+
+        self.actor = ProbabilisticActor(
+            module=actor_net,
+            in_keys=["observation"],
+            out_keys=["loc", "scale"],
             distribution_class=torch.distributions.Normal,
             return_log_prob=True
         )
-        self.actor = raw_actor  # TODO: Proj here
+        self.old_actor = deepcopy(self.actor)
+
+        self.trust_region_coef = trust_region_coef
+        self.loss_module = TRPLLoss(
+            actor_network=self.actor,
+            old_actor_network=self.old_actor,
+            critic_network=self.critic,
+            projection=self.projection,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            trust_region_coef=trust_region_coef,
+            normalize_advantage=normalize_advantage,
+        )

         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -79,7 +111,6 @@
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
         )
-
         self.adv_module = GAE(
             gamma=self.gamma,
             lmbda=gae_lambda,
@@ -87,13 +118,24 @@
             average_gae=False,
         )

-        self.loss_module = TRPLLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            trust_region_layer=self.trust_region_layer,
-            loss_critic_type='l2',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            trust_region_coef=self.trust_region_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
+    def update_old_policy(self):
+        self.old_actor.load_state_dict(self.actor.state_dict())
+
+    def project_policy(self, obs):
+        with torch.no_grad():
+            old_dist = self.old_actor(obs)
+            new_dist = self.actor(obs)
+            projected_params = self.projection.project(new_dist, old_dist)
+        return projected_params
+
+    def pre_update(self, tensordict):
+        obs = tensordict["observation"]
+        projected_dist = self.project_policy(obs)
+        # Update tensordict with projected distribution parameters
+        tensordict["projected_loc"] = projected_dist[0]
+        tensordict["projected_scale"] = projected_dist[1]
+        return tensordict
+
+    def post_update(self):
+        self.update_old_policy()
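
Note: after this refactor the TRPL constructor wires up the projection, the old-policy copy, and TRPLLoss internally, so a caller only selects a projection by name and sets the trust-region bounds. A rough usage sketch; the environment id is an assumption about what env_spec / make_env accept, and "identity_projection" is simply the default name from the signature above:

    import torch
    from fancy_rl.algos import TRPL

    agent = TRPL(
        env_spec="HalfCheetah-v4",                # assumed: an env id that make_env() can build
        projection_class="identity_projection",   # resolved through get_projection(...)
        trust_region_bound_mean=0.1,
        trust_region_bound_cov=0.001,
        trust_region_coef=10.0,
        device=torch.device("cpu"),
    )
    agent.train()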