diff --git a/fancy_rl/algos/__init__.py b/fancy_rl/algos/__init__.py
index 2ae3d06..040bf9c 100644
--- a/fancy_rl/algos/__init__.py
+++ b/fancy_rl/algos/__init__.py
@@ -1 +1,3 @@
-from fancy_rl.algos.ppo import PPO
\ No newline at end of file
+from fancy_rl.algos.ppo import PPO
+from fancy_rl.algos.trpl import TRPL
+from fancy_rl.algos.vlearn import VLEARN
\ No newline at end of file
diff --git a/fancy_rl/algos/on_policy.py b/fancy_rl/algos/on_policy.py
index 9b61c54..9da5355 100644
--- a/fancy_rl/algos/on_policy.py
+++ b/fancy_rl/algos/on_policy.py
@@ -15,23 +15,17 @@ class OnPolicy(Algo):
         env_spec,
         optimizers,
         loggers=None,
-        actor_hidden_sizes=[64, 64],
-        critic_hidden_sizes=[64, 64],
-        actor_activation_fn="Tanh",
-        critic_activation_fn="Tanh",
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
         n_epochs=10,
         gamma=0.99,
-        gae_lambda=0.95,
         total_timesteps=1e6,
         eval_interval=2048,
         eval_deterministic=True,
         entropy_coef=0.01,
         critic_coef=0.5,
         normalize_advantage=True,
-        clip_range=0.2,
         env_spec_eval=None,
         eval_episodes=10,
         device=None,
@@ -77,15 +71,25 @@ class OnPolicy(Algo):
             batch_size=self.batch_size,
         )
 
+    def pre_process_batch(self, batch):
+        return batch
+
+    def post_process_batch(self, batch):
+        pass
 
     def train_step(self, batch):
+        batch = self.pre_process_batch(batch)
+
         for optimizer in self.optimizers.values():
             optimizer.zero_grad()
         losses = self.loss_module(batch)
-        loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
+        loss = sum(losses.values())  # Sum all losses
         loss.backward()
         for optimizer in self.optimizers.values():
             optimizer.step()
+
+        self.post_process_batch(batch)
+
         return loss
 
     def train(self):
diff --git a/fancy_rl/algos/ppo.py b/fancy_rl/algos/ppo.py
index c74a546..d776811 100644
--- a/fancy_rl/algos/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -4,6 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
+from fancy_rl.projections import get_projection  # Updated import
 
 class PPO(OnPolicy):
     def __init__(
diff --git a/fancy_rl/algos/trpl.py b/fancy_rl/algos/trpl.py
index 9dafb94..8160e2b 100644
--- a/fancy_rl/algos/trpl.py
+++ b/fancy_rl/algos/trpl.py
@@ -1,9 +1,16 @@
 import torch
-from torchrl.modules import ProbabilisticActor
-from torchrl.objectives.value.advantages import GAE
+from torch import nn
+from typing import Dict, Any, Optional
+from torchrl.modules import ProbabilisticActor, ValueOperator
+from torchrl.objectives import ClipPPOLoss
+from torchrl.collectors import SyncDataCollector
+from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
+from torchrl.objectives.value import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
+from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from copy import deepcopy
 
 class TRPL(OnPolicy):
     def __init__(
@@ -14,19 +21,21 @@ class TRPL(OnPolicy):
         critic_hidden_sizes=[64, 64],
         actor_activation_fn="Tanh",
         critic_activation_fn="Tanh",
-        proj_layer_type=None,
         learning_rate=3e-4,
         n_steps=2048,
         batch_size=64,
         n_epochs=10,
         gamma=0.99,
         gae_lambda=0.95,
+        projection_class="identity_projection",
+        trust_region_coef=10.0,
+        trust_region_bound_mean=0.1,
+        trust_region_bound_cov=0.001,
         total_timesteps=1e6,
         eval_interval=2048,
         eval_deterministic=True,
         entropy_coef=0.01,
         critic_coef=0.5,
-        trust_region_coef=10.0,
         normalize_advantage=False,
         device=None,
         env_spec_eval=None,
@@ -35,9 +44,6 @@ class TRPL(OnPolicy):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device
 
-        self.trust_region_layer = None  # TODO: from proj_layer_type
-        self.trust_region_coef = trust_region_coef
-
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
@@ -46,14 +52,40 @@ class TRPL(OnPolicy):
 
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
         actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        raw_actor = ProbabilisticActor(
-            module=actor_net,
+
+        # Handle projection_class
+        if isinstance(projection_class, str):
+            projection_class = get_projection(projection_class)
+        elif not issubclass(projection_class, BaseProjection):
+            raise ValueError("projection_class must be a string or a subclass of BaseProjection")
+
+        self.projection = projection_class(
             in_keys=["loc", "scale"],
-            out_keys=["action"],
+            out_keys=["loc", "scale"],
+            trust_region_bound_mean=trust_region_bound_mean,
+            trust_region_bound_cov=trust_region_bound_cov
+        )
+
+        self.actor = ProbabilisticActor(
+            module=actor_net,
+            in_keys=["observation"],
+            out_keys=["loc", "scale"],
             distribution_class=torch.distributions.Normal,
             return_log_prob=True
         )
-        self.actor = raw_actor  # TODO: Proj here
+        self.old_actor = deepcopy(self.actor)
+
+        self.trust_region_coef = trust_region_coef
+        self.loss_module = TRPLLoss(
+            actor_network=self.actor,
+            old_actor_network=self.old_actor,
+            critic_network=self.critic,
+            projection=self.projection,
+            entropy_coef=entropy_coef,
+            critic_coef=critic_coef,
+            trust_region_coef=trust_region_coef,
+            normalize_advantage=normalize_advantage,
+        )
 
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -79,7 +111,6 @@ class TRPL(OnPolicy):
             env_spec_eval=env_spec_eval,
             eval_episodes=eval_episodes,
         )
-
         self.adv_module = GAE(
             gamma=self.gamma,
             lmbda=gae_lambda,
@@ -87,13 +118,24 @@ class TRPL(OnPolicy):
             average_gae=False,
         )
 
-        self.loss_module = TRPLLoss(
-            actor_network=self.actor,
-            critic_network=self.critic,
-            trust_region_layer=self.trust_region_layer,
-            loss_critic_type='l2',
-            entropy_coef=self.entropy_coef,
-            critic_coef=self.critic_coef,
-            trust_region_coef=self.trust_region_coef,
-            normalize_advantage=self.normalize_advantage,
-        )
+    def update_old_policy(self):
+        self.old_actor.load_state_dict(self.actor.state_dict())
+
+    def project_policy(self, obs):
+        with torch.no_grad():
+            old_dist = self.old_actor(obs)
+            new_dist = self.actor(obs)
+            projected_params = self.projection.project(new_dist, old_dist)
+        return projected_params
+
+    def pre_update(self, tensordict):
+        obs = tensordict["observation"]
+        projected_dist = self.project_policy(obs)
+
+        # Update tensordict with projected distribution parameters
+        tensordict["projected_loc"] = projected_dist[0]
+        tensordict["projected_scale"] = projected_dist[1]
+        return tensordict
+
+    def post_update(self):
+        self.update_old_policy()