refactor algo impls

Dominik Moritz Roth 2024-08-28 11:33:20 +02:00
parent dd98af9f77
commit 4f58ce0ff2
4 changed files with 79 additions and 30 deletions

View File

@@ -1 +1,3 @@
from fancy_rl.algos.ppo import PPO
from fancy_rl.algos.trpl import TRPL
from fancy_rl.algos.vlearn import VLEARN
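For orientation, the re-exports above give the algorithms a flat import path. A hypothetical usage sketch follows; env_spec, total_timesteps, projection_class, and the trust-region bounds appear elsewhere in this diff, while the environment id and everything else is assumed:

# Hypothetical usage sketch based on the re-exports above; not taken from fancy_rl's docs.
from fancy_rl.algos import PPO, TRPL

ppo = PPO(env_spec="HalfCheetah-v4", total_timesteps=1_000_000)  # env id is an assumption
ppo.train()

trpl = TRPL(
    env_spec="HalfCheetah-v4",
    projection_class="identity_projection",  # default name seen later in this diff
    trust_region_bound_mean=0.1,
    trust_region_bound_cov=0.001,
)
trpl.train()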

View File

@@ -15,23 +15,17 @@ class OnPolicy(Algo):
env_spec,
optimizers,
loggers=None,
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
total_timesteps=1e6,
eval_interval=2048,
eval_deterministic=True,
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
clip_range=0.2,
env_spec_eval=None,
eval_episodes=10,
device=None,
@@ -77,15 +71,25 @@ class OnPolicy(Algo):
batch_size=self.batch_size,
)
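The truncated call above (most likely the rollout buffer, given its batch_size argument) pairs with the TorchRL data components this commit imports in the TRPL module. A generic sketch of that setup, not necessarily fancy_rl's exact code:

# Generic TorchRL rollout-buffer sketch using components imported later in this
# commit (TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement).
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(2048),      # roughly n_steps worth of transitions
    sampler=SamplerWithoutReplacement(),  # each transition sampled once per epoch
    batch_size=64,
)
# buffer.extend(rollout_tensordict); minibatch = buffer.sample()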
def pre_process_batch(self, batch):
return batch
def post_process_batch(self, batch):
pass
def train_step(self, batch):
batch = self.pre_process_batch(batch)
for optimizer in self.optimizers.values():
optimizer.zero_grad()
losses = self.loss_module(batch)
loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
loss = sum(losses.values()) # Sum all losses
loss.backward()
for optimizer in self.optimizers.values():
optimizer.step()
self.post_process_batch(batch)
return loss
def train(self):
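A note on the generic summation introduced in train_step: TorchRL loss modules key their loss terms with a "loss_" prefix and can also return detached diagnostics (entropy, ESS, and similar), so a defensive variant sums only the prefixed entries. A minimal sketch of that variant, not what this commit does:

def _total_loss(losses):
    # Sum only the "loss_"-prefixed entries returned by the loss module,
    # skipping any detached diagnostics it may also report.
    return sum(value for key, value in losses.items() if key.startswith("loss_"))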

View File

@@ -4,6 +4,7 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection # Updated import
class PPO(OnPolicy):
def __init__(
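The get_projection helper referenced by the updated import above is not shown in this diff. A hypothetical sketch of a string-to-class resolver of that shape, with placeholder classes (only the name "identity_projection" and the BaseProjection base appear in this diff):

# Hypothetical sketch of a get_projection-style resolver; fancy_rl's actual
# registry may differ. The classes below are placeholders, not real fancy_rl code.
class BaseProjection:
    """Placeholder for fancy_rl.projections.BaseProjection."""

class IdentityProjection(BaseProjection):
    """Placeholder: returns loc/scale unchanged."""

_PROJECTIONS = {
    "identity_projection": IdentityProjection,
}

def get_projection(name):
    # Resolve a projection name such as "identity_projection" to its class.
    try:
        return _PROJECTIONS[name]
    except KeyError as exc:
        raise ValueError(
            f"Unknown projection '{name}'; available: {sorted(_PROJECTIONS)}"
        ) from exc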

View File

@@ -1,9 +1,16 @@
import torch
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.value.advantages import GAE
from torch import nn
from typing import Dict, Any, Optional
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
from torchrl.objectives.value import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
class TRPL(OnPolicy):
def __init__(
@@ -14,19 +21,21 @@ class TRPL(OnPolicy):
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
proj_layer_type=None,
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
projection_class="identity_projection",
trust_region_coef=10.0,
trust_region_bound_mean=0.1,
trust_region_bound_cov=0.001,
total_timesteps=1e6,
eval_interval=2048,
eval_deterministic=True,
entropy_coef=0.01,
critic_coef=0.5,
trust_region_coef=10.0,
normalize_advantage=False,
device=None,
env_spec_eval=None,
@@ -35,9 +44,6 @@ class TRPL(OnPolicy):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
self.trust_region_layer = None # TODO: from proj_layer_type
self.trust_region_coef = trust_region_coef
# Initialize environment to get observation and action space sizes
self.env_spec = env_spec
env = self.make_env()
@@ -46,14 +52,40 @@ class TRPL(OnPolicy):
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
raw_actor = ProbabilisticActor(
module=actor_net,
# Handle projection_class
if isinstance(projection_class, str):
projection_class = get_projection(projection_class)
elif not issubclass(projection_class, BaseProjection):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
self.projection = projection_class(
in_keys=["loc", "scale"],
out_keys=["action"],
out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)
self.actor = raw_actor # TODO: Proj here
self.old_actor = deepcopy(self.actor)
self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(
actor_network=self.actor,
old_actor_network=self.old_actor,
critic_network=self.critic,
projection=self.projection,
entropy_coef=entropy_coef,
critic_coef=critic_coef,
trust_region_coef=trust_region_coef,
normalize_advantage=normalize_advantage,
)
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -79,7 +111,6 @@ class TRPL(OnPolicy):
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
)
self.adv_module = GAE(
gamma=self.gamma,
lmbda=gae_lambda,
@@ -87,13 +118,24 @@ class TRPL(OnPolicy):
average_gae=False,
)
self.loss_module = TRPLLoss(
actor_network=self.actor,
critic_network=self.critic,
trust_region_layer=self.trust_region_layer,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
trust_region_coef=self.trust_region_coef,
normalize_advantage=self.normalize_advantage,
)
def update_old_policy(self):
self.old_actor.load_state_dict(self.actor.state_dict())
def project_policy(self, obs):
with torch.no_grad():
old_dist = self.old_actor(obs)
new_dist = self.actor(obs)
projected_params = self.projection.project(new_dist, old_dist)
return projected_params
def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)
# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict
def post_update(self):
self.update_old_policy()
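Taken together, update_old_policy, project_policy, pre_update, and post_update imply a per-update cycle along the following lines. This is a hedged sketch of the intended wiring; OnPolicy's hooks in this diff are named pre_process_batch/post_process_batch, so hooking pre_update/post_update into the loop is assumed here:

def run_update(algo, batch):
    # Sketch of the intended TRPL cycle, not the actual OnPolicy loop.
    batch = algo.pre_update(batch)  # attach projected_loc/projected_scale from the frozen old_actor
    loss = algo.train_step(batch)   # backprop through TRPLLoss (surrogate, critic, entropy, trust-region terms)
    algo.post_update()              # old_actor <- actor for the next update
    return loss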