Compare commits

...

10 Commits

23 changed files with 916 additions and 123 deletions

View File

@ -4,6 +4,7 @@ try:
except ImportError: except ImportError:
pass pass
from fancy_rl.algos import PPO from fancy_rl.algos import PPO, TRPL, VLEARN
from fancy_rl.projections import get_projection
__all__ = ["PPO"] __all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]

View File

@ -1 +1,3 @@
from fancy_rl.algos.ppo import PPO from fancy_rl.algos.ppo import PPO
from fancy_rl.algos.trpl import TRPL
from fancy_rl.algos.vlearn import VLEARN

View File

@ -15,23 +15,17 @@ class OnPolicy(Algo):
env_spec, env_spec,
optimizers, optimizers,
loggers=None, loggers=None,
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
learning_rate=3e-4, learning_rate=3e-4,
n_steps=2048, n_steps=2048,
batch_size=64, batch_size=64,
n_epochs=10, n_epochs=10,
gamma=0.99, gamma=0.99,
gae_lambda=0.95,
total_timesteps=1e6, total_timesteps=1e6,
eval_interval=2048, eval_interval=2048,
eval_deterministic=True, eval_deterministic=True,
entropy_coef=0.01, entropy_coef=0.01,
critic_coef=0.5, critic_coef=0.5,
normalize_advantage=True, normalize_advantage=True,
clip_range=0.2,
env_spec_eval=None, env_spec_eval=None,
eval_episodes=10, eval_episodes=10,
device=None, device=None,
@ -77,15 +71,25 @@ class OnPolicy(Algo):
batch_size=self.batch_size, batch_size=self.batch_size,
) )
def pre_process_batch(self, batch):
return batch
def post_process_batch(self, batch):
pass
def train_step(self, batch): def train_step(self, batch):
batch = self.pre_process_batch(batch)
for optimizer in self.optimizers.values(): for optimizer in self.optimizers.values():
optimizer.zero_grad() optimizer.zero_grad()
losses = self.loss_module(batch) losses = self.loss_module(batch)
loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"] loss = sum(losses.values()) # Sum all losses
loss.backward() loss.backward()
for optimizer in self.optimizers.values(): for optimizer in self.optimizers.values():
optimizer.step() optimizer.step()
self.post_process_batch(batch)
return loss return loss
def train(self): def train(self):

View File

@ -4,6 +4,7 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection # Updated import
class PPO(OnPolicy): class PPO(OnPolicy):
def __init__( def __init__(

View File

@ -1,9 +1,16 @@
import torch import torch
from torchrl.modules import ProbabilisticActor from torch import nn
from torchrl.objectives.value.advantages import GAE from typing import Dict, Any, Optional
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
from torchrl.objectives.value import GAE
from fancy_rl.algos.on_policy import OnPolicy from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
class TRPL(OnPolicy): class TRPL(OnPolicy):
def __init__( def __init__(
@ -14,19 +21,21 @@ class TRPL(OnPolicy):
critic_hidden_sizes=[64, 64], critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh", actor_activation_fn="Tanh",
critic_activation_fn="Tanh", critic_activation_fn="Tanh",
proj_layer_type=None,
learning_rate=3e-4, learning_rate=3e-4,
n_steps=2048, n_steps=2048,
batch_size=64, batch_size=64,
n_epochs=10, n_epochs=10,
gamma=0.99, gamma=0.99,
gae_lambda=0.95, gae_lambda=0.95,
projection_class="identity_projection",
trust_region_coef=10.0,
trust_region_bound_mean=0.1,
trust_region_bound_cov=0.001,
total_timesteps=1e6, total_timesteps=1e6,
eval_interval=2048, eval_interval=2048,
eval_deterministic=True, eval_deterministic=True,
entropy_coef=0.01, entropy_coef=0.01,
critic_coef=0.5, critic_coef=0.5,
trust_region_coef=10.0,
normalize_advantage=False, normalize_advantage=False,
device=None, device=None,
env_spec_eval=None, env_spec_eval=None,
@ -35,9 +44,6 @@ class TRPL(OnPolicy):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device self.device = device
self.trust_region_layer = None # TODO: from proj_layer_type
self.trust_region_coef = trust_region_coef
# Initialize environment to get observation and action space sizes # Initialize environment to get observation and action space sizes
self.env_spec = env_spec self.env_spec = env_spec
env = self.make_env() env = self.make_env()
@ -46,14 +52,40 @@ class TRPL(OnPolicy):
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device) self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device) actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
raw_actor = ProbabilisticActor(
module=actor_net, # Handle projection_class
if isinstance(projection_class, str):
projection_class = get_projection(projection_class)
elif not issubclass(projection_class, BaseProjection):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
self.projection = projection_class(
in_keys=["loc", "scale"], in_keys=["loc", "scale"],
out_keys=["action"], out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal, distribution_class=torch.distributions.Normal,
return_log_prob=True return_log_prob=True
) )
self.actor = raw_actor # TODO: Proj here self.old_actor = deepcopy(self.actor)
self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(
actor_network=self.actor,
old_actor_network=self.old_actor,
critic_network=self.critic,
projection=self.projection,
entropy_coef=entropy_coef,
critic_coef=critic_coef,
trust_region_coef=trust_region_coef,
normalize_advantage=normalize_advantage,
)
optimizers = { optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate), "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@ -79,7 +111,6 @@ class TRPL(OnPolicy):
env_spec_eval=env_spec_eval, env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes, eval_episodes=eval_episodes,
) )
self.adv_module = GAE( self.adv_module = GAE(
gamma=self.gamma, gamma=self.gamma,
lmbda=gae_lambda, lmbda=gae_lambda,
@ -87,13 +118,24 @@ class TRPL(OnPolicy):
average_gae=False, average_gae=False,
) )
self.loss_module = TRPLLoss( def update_old_policy(self):
actor_network=self.actor, self.old_actor.load_state_dict(self.actor.state_dict())
critic_network=self.critic,
trust_region_layer=self.trust_region_layer, def project_policy(self, obs):
loss_critic_type='l2', with torch.no_grad():
entropy_coef=self.entropy_coef, old_dist = self.old_actor(obs)
critic_coef=self.critic_coef, new_dist = self.actor(obs)
trust_region_coef=self.trust_region_coef, projected_params = self.projection.project(new_dist, old_dist)
normalize_advantage=self.normalize_advantage, return projected_params
)
def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)
# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict
def post_update(self):
self.update_old_policy()

114
fancy_rl/algos/vlearn.py Normal file
View File

@ -0,0 +1,114 @@
import torch
from torch import nn
from typing import Dict, Any, Optional
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from fancy_rl.utils import get_env, get_actor, get_critic
from fancy_rl.modules.vlearn_loss import VLEARNLoss
from fancy_rl.modules.projection import get_vlearn_projection
from fancy_rl.modules.squashed_normal import get_squashed_normal
class VLEARN:
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):
self.device = torch.device(device)
self.env = get_env(env_id)
self.projection = get_vlearn_projection(**kwargs.get("projection", {}))
actor = get_actor(self.env, **kwargs.get("actor", {}))
self.actor = ProbabilisticActor(
actor,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=get_squashed_normal(),
return_log_prob=True
).to(self.device)
self.old_actor = self.actor.clone()
self.critic = ValueOperator(
module=get_critic(self.env, **kwargs.get("critic", {})),
in_keys=["observation"]
).to(self.device)
self.collector = SyncDataCollector(
self.env,
self.actor,
frames_per_batch=kwargs.get("frames_per_batch", 1000),
total_frames=kwargs.get("total_frames", -1),
device=self.device,
)
self.replay_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(kwargs.get("buffer_size", 100000)),
batch_size=kwargs.get("batch_size", 256),
)
self.loss_module = VLEARNLoss(
actor_network=self.actor,
critic_network=self.critic,
old_actor_network=self.old_actor,
projection=self.projection,
**kwargs.get("loss", {})
)
self.optimizers = nn.ModuleDict({
"policy": torch.optim.Adam(self.actor.parameters(), lr=kwargs.get("lr_policy", 3e-4)),
"critic": torch.optim.Adam(self.critic.parameters(), lr=kwargs.get("lr_critic", 3e-4))
})
self.update_policy_interval = kwargs.get("update_policy_interval", 1)
self.update_critic_interval = kwargs.get("update_critic_interval", 1)
self.target_update_interval = kwargs.get("target_update_interval", 1)
self.polyak_weight_critic = kwargs.get("polyak_weight_critic", 0.995)
def train(self, num_iterations: int = 1000) -> None:
for i in range(num_iterations):
data = next(self.collector)
self.replay_buffer.extend(data)
batch = self.replay_buffer.sample().to(self.device)
loss_dict = self.loss_module(batch)
if i % self.update_policy_interval == 0:
self.optimizers["policy"].zero_grad()
loss_dict["policy_loss"].backward()
self.optimizers["policy"].step()
if i % self.update_critic_interval == 0:
self.optimizers["critic"].zero_grad()
loss_dict["critic_loss"].backward()
self.optimizers["critic"].step()
if i % self.target_update_interval == 0:
self.critic.update_target_params(self.polyak_weight_critic)
self.old_actor.load_state_dict(self.actor.state_dict())
self.collector.update_policy_weights_()
if i % 100 == 0:
eval_reward = self.eval()
print(f"Iteration {i}, Eval reward: {eval_reward}")
def eval(self, num_episodes: int = 10) -> float:
total_reward = 0
for _ in range(num_episodes):
td = self.env.reset()
done = False
while not done:
with torch.no_grad():
action = self.actor(td.to(self.device))["action"]
td = self.env.step(action)
total_reward += td["reward"].item()
done = td["done"].item()
return total_reward / num_episodes
def save_policy(self, path: str) -> None:
torch.save(self.actor.state_dict(), f"{path}/actor.pth")
torch.save(self.critic.state_dict(), f"{path}/critic.pth")
def load_policy(self, path: str) -> None:
self.actor.load_state_dict(torch.load(f"{path}/actor.pth"))
self.critic.load_state_dict(torch.load(f"{path}/critic.pth"))
self.old_actor.load_state_dict(self.actor.state_dict())

View File

@ -38,100 +38,40 @@ from torchrl.objectives.value import (
) )
from torchrl.objectives.ppo import PPOLoss from torchrl.objectives.ppo import PPOLoss
from fancy_rl.projections import get_projection
class TRPLLoss(PPOLoss): class TRPLLoss(PPOLoss):
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values
Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
sample_log_prob (NestedKey): The input tensordict key where the
sample log probability is expected. Defaults to ``"sample_log_prob"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is terminated. Will be used for the underlying value estimator.
Defaults to ``"terminated"``.
"""
advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.GAE
actor_network: TensorDictModule
critic_network: TensorDictModule
actor_network_params: TensorDictParams
critic_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_critic_network_params: TensorDictParams
def __init__( def __init__(
self, self,
actor_network: ProbabilisticTensorDictSequential | None = None, actor_network: ProbabilisticTensorDictSequential,
critic_network: TensorDictModule | None = None, old_actor_network: ProbabilisticTensorDictSequential,
trust_region_layer: any | None = None, critic_network: TensorDictModule,
entropy_bonus: bool = True, projection: any,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01, entropy_coef: float = 0.01,
critic_coef: float = 1.0, critic_coef: float = 1.0,
trust_region_coef: float = 10.0, trust_region_coef: float = 10.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False, normalize_advantage: bool = False,
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
clip_value: bool | float | None = None,
**kwargs, **kwargs,
): ):
self.trust_region_layer = trust_region_layer super().__init__(
self.trust_region_coef = trust_region_coef actor_network=actor_network,
critic_network=critic_network,
super(TRPLLoss, self).__init__(
actor_network,
critic_network,
entropy_bonus=entropy_bonus,
samples_mc_entropy=samples_mc_entropy,
entropy_coef=entropy_coef, entropy_coef=entropy_coef,
critic_coef=critic_coef, critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage, normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
clip_value=clip_value,
**kwargs, **kwargs,
) )
self.old_actor_network = old_actor_network
self.projection = projection
self.trust_region_coef = trust_region_coef
@property @property
def out_keys(self): def out_keys(self):
if self._out_keys is None: if self._out_keys is None:
keys = ["loss_objective"] keys = ["loss_objective", "tr_loss"]
if self.entropy_bonus: if self.entropy_bonus:
keys.extend(["entropy", "loss_entropy"]) keys.extend(["entropy", "loss_entropy"])
if self.loss_critic: if self.critic_coef:
keys.append("loss_critic") keys.append("loss_critic")
keys.append("ESS") keys.append("ESS")
self._out_keys = keys self._out_keys = keys
@ -141,8 +81,12 @@ class TRPLLoss(PPOLoss):
def out_keys(self, values): def out_keys(self, values):
self._out_keys = values self._out_keys = values
@dispatch def _trust_region_loss(self, tensordict):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase: old_distribution = self.old_actor_network(tensordict)
raw_distribution = self.actor_network(tensordict)
return self.projection(self.actor_network, raw_distribution, old_distribution)
def forward(self, tensordict: TensorDictBase) -> TensorDict:
tensordict = tensordict.clone(False) tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None) advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None: if advantage is None:
@ -159,11 +103,8 @@ class TRPLLoss(PPOLoss):
log_weight, dist, kl_approx = self._log_weight(tensordict) log_weight, dist, kl_approx = self._log_weight(tensordict)
trust_region_loss_unscaled = self._trust_region_loss(tensordict) trust_region_loss_unscaled = self._trust_region_loss(tensordict)
# ESS for logging
with torch.no_grad(): with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
# to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
# of the weights.
lw = log_weight.squeeze() lw = log_weight.squeeze()
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp() ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
batch = log_weight.shape[0] batch = log_weight.shape[0]
@ -194,8 +135,3 @@ class TRPLLoss(PPOLoss):
batch_size=[], batch_size=[],
) )
return td_out return td_out
def _trust_region_loss(self, tensordict):
old_distribution =
raw_distribution =
return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)

View File

@ -0,0 +1,106 @@
import torch
from torchrl.objectives import LossModule
from torch.distributions import Normal
class VLEARNLoss(LossModule):
def __init__(
self,
actor_network,
critic_network,
old_actor_network,
gamma=0.99,
lmbda=0.95,
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
eps=1e-8,
delta=0.1
):
super().__init__()
self.actor_network = actor_network
self.critic_network = critic_network
self.old_actor_network = old_actor_network
self.gamma = gamma
self.lmbda = lmbda
self.entropy_coef = entropy_coef
self.critic_coef = critic_coef
self.normalize_advantage = normalize_advantage
self.eps = eps
self.delta = delta
def forward(self, tensordict):
# Compute returns and advantages
with torch.no_grad():
returns = self.compute_returns(tensordict)
values = self.critic_network(tensordict)["state_value"]
advantages = returns - values
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Compute actor loss
new_td = self.actor_network(tensordict)
old_td = self.old_actor_network(tensordict)
new_dist = Normal(new_td["loc"], new_td["scale"])
old_dist = Normal(old_td["loc"], old_td["scale"])
new_log_prob = new_dist.log_prob(tensordict["action"]).sum(-1)
old_log_prob = old_dist.log_prob(tensordict["action"]).sum(-1)
ratio = torch.exp(new_log_prob - old_log_prob)
# Compute projection
kl = torch.distributions.kl.kl_divergence(new_dist, old_dist).sum(-1)
alpha = torch.where(kl > self.delta,
torch.sqrt(self.delta / (kl + self.eps)),
torch.ones_like(kl))
proj_loc = alpha.unsqueeze(-1) * new_td["loc"] + (1 - alpha.unsqueeze(-1)) * old_td["loc"]
proj_scale = torch.sqrt(alpha.unsqueeze(-1)**2 * new_td["scale"]**2 + (1 - alpha.unsqueeze(-1))**2 * old_td["scale"]**2)
proj_dist = Normal(proj_loc, proj_scale)
proj_log_prob = proj_dist.log_prob(tensordict["action"]).sum(-1)
proj_ratio = torch.exp(proj_log_prob - old_log_prob)
policy_loss = -torch.min(
ratio * advantages,
proj_ratio * advantages
).mean()
# Compute critic loss
value_pred = self.critic_network(tensordict)["state_value"]
critic_loss = 0.5 * (returns - value_pred).pow(2).mean()
# Compute entropy loss
entropy_loss = -self.entropy_coef * new_dist.entropy().mean()
# Combine losses
loss = policy_loss + self.critic_coef * critic_loss + entropy_loss
return {
"loss": loss,
"policy_loss": policy_loss,
"critic_loss": critic_loss,
"entropy_loss": entropy_loss,
}
def compute_returns(self, tensordict):
rewards = tensordict["reward"]
dones = tensordict["done"]
values = self.critic_network(tensordict)["state_value"]
returns = torch.zeros_like(rewards)
advantages = torch.zeros_like(rewards)
last_gae_lam = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = 0
else:
next_value = values[t + 1]
delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
advantages[t] = last_gae_lam = delta + self.gamma * self.lmbda * (1 - dones[t]) * last_gae_lam
returns = advantages + values
return returns

View File

@ -1,6 +1,20 @@
try: from .base_projection import BaseProjection
import cpp_projection from .identity_projection import IdentityProjection
except ModuleNotFoundError: from .kl_projection import KLProjection
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer from .wasserstein_projection import WassersteinProjection
else: from .frobenius_projection import FrobeniusProjection
from .kl_projection_layer import KLProjectionLayer
def get_projection(projection_name: str):
projections = {
"identity_projection": IdentityProjection,
"kl_projection": KLProjection,
"wasserstein_projection": WassersteinProjection,
"frobenius_projection": FrobeniusProjection,
}
projection = projections.get(projection_name.lower())
if projection is None:
raise ValueError(f"Unknown projection: {projection_name}")
return projection
__all__ = ["BaseProjection", "IdentityProjection", "KLProjection", "WassersteinProjection", "FrobeniusProjection", "get_projection"]

View File

@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict
class BaseProjection(ABC, torch.nn.Module):
def __init__(self, in_keys: list[str], out_keys: list[str]):
super().__init__()
self.in_keys = in_keys
self.out_keys = out_keys
@abstractmethod
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pass
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.project(policy_params, old_policy_params)

View File

@ -0,0 +1,67 @@
import torch
from .base_projection import BaseProjection
from typing import Dict
class FrobeniusProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
proj_chol = torch.linalg.cholesky(proj_cov)
return {"loc": proj_mean, "scale_tril": proj_chol}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
def _gaussian_frobenius(self, p, q):
mean, cov = p
old_mean, old_cov = q
if self.scale_prec:
prec_old = torch.inverse(old_cov)
mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
else:
mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
return mean_part, cov_part
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.sqrt(mean_part)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
batch_shape = cov.shape[:-2]
cov_mask = cov_part > self.cov_bound
eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
eta = torch.max(-eta, eta)
new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
return proj_cov

View File

@ -0,0 +1,13 @@
import torch
from .base_projection import BaseProjection
from typing import Dict
class IdentityProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return policy_params
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
return torch.tensor(0.0, device=next(iter(policy_params.values())).device)

View File

@ -0,0 +1,199 @@
import torch
import cpp_projection
import numpy as np
from .base_projection import BaseProjection
from typing import Dict, Tuple, Any
MAX_EVAL = 1000
def get_numpy(tensor):
return tensor.detach().cpu().numpy()
class KLProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.is_diag = is_diag
self.contextual_std = contextual_std
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, std = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
if not self.contextual_std:
std = std[:1]
old_std = old_std[:1]
cov_part = cov_part[:1]
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_std = self._cov_projection(std, old_std, cov_part)
if not self.contextual_std:
proj_std = proj_std.expand(mean.shape[0], -1, -1)
return {"loc": proj_mean, "scale_tril": proj_std}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, std = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
return kl.mean() * self.trust_region_coeff
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
mean, std = p
mean_other, std_other = q
k = mean.shape[-1]
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
det_term = self._log_determinant(std)
det_term_other = self._log_determinant(std_other)
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
return maha_part, cov_part
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
diff = x - y
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
return torch.sum(x.pow(2), dim=(-2, -1))
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
cov = torch.matmul(std, std.transpose(-1, -2))
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
if self.is_diag:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
try:
if mask.any():
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
old_cov.diagonal(dim1=-2, dim2=-1),
self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
except Exception as e:
proj_std = old_std
else:
try:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
if mask.any():
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_std[failed_mask] = old_std[failed_mask]
except Exception as e:
import logging
logging.error('Projection failed, taking old cholesky for projection.')
print("Projection failed, taking old cholesky for projection.")
proj_std = old_std
raise e
return proj_std
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
projection_op = None
@staticmethod
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
if not KLProjectionGradFunctionCovOnly.projection_op:
KLProjectionGradFunctionCovOnly.projection_op = \
cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
return KLProjectionGradFunctionCovOnly.projection_op
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
cov, chol, old_chol, eps_cov = args
batch_shape = cov.shape[0]
dim = cov.shape[-1]
cov_np = get_numpy(cov)
chol_np = get_numpy(chol)
old_chol_np = get_numpy(old_chol)
eps = get_numpy(eps_cov) * np.ones(batch_shape)
p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim)
ctx.proj = p_op
proj_std = p_op.forward(eps, old_chol_np, chol_np, cov_np)
return cov.new(proj_std)
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
projection_op = ctx.proj
d_cov, = grad_outputs
d_cov_np = get_numpy(d_cov)
d_cov_np = np.atleast_2d(d_cov_np)
df_stds = projection_op.backward(d_cov_np)
df_stds = np.atleast_2d(df_stds)
df_stds = d_cov.new(df_stds)
return df_stds, None, None, None
class KLProjectionGradFunctionDiagCovOnly(torch.autograd.Function):
projection_op = None
@staticmethod
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
if not KLProjectionGradFunctionDiagCovOnly.projection_op:
KLProjectionGradFunctionDiagCovOnly.projection_op = \
cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
return KLProjectionGradFunctionDiagCovOnly.projection_op
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
cov, old_cov, eps_cov = args
batch_shape = cov.shape[0]
dim = cov.shape[-1]
cov_np = get_numpy(cov)
old_cov_np = get_numpy(old_cov)
eps = get_numpy(eps_cov) * np.ones(batch_shape)
p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
ctx.proj = p_op
proj_std = p_op.forward(eps, old_cov_np, cov_np)
return cov.new(proj_std)
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
projection_op = ctx.proj
d_std, = grad_outputs
d_cov_np = get_numpy(d_std)
d_cov_np = np.atleast_2d(d_cov_np)
df_stds = projection_op.backward(d_cov_np)
df_stds = np.atleast_2d(df_stds)
return d_std.new(df_stds), None, None

View File

@ -0,0 +1,56 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, Tuple
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
mean, sqrt = p
mean_other, sqrt_other = q
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
return {"loc": proj_mean, "scale_tril": proj_sqrt}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.norm(diff, dim=-1, keepdim=True)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
diff = sqrt - old_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)

View File

@ -0,0 +1,6 @@
try:
import cpp_projection
except ModuleNotFoundError:
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
else:
from .kl_projection_layer import KLProjectionLayer

View File

@ -1 +1,59 @@
# TODO import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_ppo_instantiation():
ppo = PPO("CartPole-v1")
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
gamma=gamma,
clip_range=clip_range
)
assert ppo.learning_rate == learning_rate
assert ppo.n_steps == n_steps
assert ppo.batch_size == batch_size
assert ppo.n_epochs == n_epochs
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range
def test_ppo_predict(simple_env):
ppo = PPO("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_ppo_learn():
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = ppo.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
ppo.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = ppo.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss

77
test/test_trpl.py Normal file
View File

@ -0,0 +1,77 @@
import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_trpl_instantiation():
trpl = TRPL("CartPole-v1")
assert isinstance(trpl, TRPL)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
trpl = TRPL(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
gamma=gamma,
max_kl=max_kl
)
assert trpl.learning_rate == learning_rate
assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size
assert trpl.gamma == gamma
assert trpl.max_kl == max_kl
def test_trpl_predict(simple_env):
trpl = TRPL("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_trpl_learn():
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = trpl.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
trpl.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = trpl.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_trpl_training(simple_env):
trpl = TRPL("CartPole-v1", total_timesteps=10000)
initial_performance = evaluate_policy(trpl, simple_env)
trpl.train()
final_performance = evaluate_policy(trpl, simple_env)
assert final_performance > initial_performance, "TRPL should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes

81
test/test_vlearn.py Normal file
View File

@ -0,0 +1,81 @@
import pytest
import torch
import numpy as np
from fancy_rl import VLEARN
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_vlearn_instantiation():
vlearn = VLEARN("CartPole-v1")
assert isinstance(vlearn, VLEARN)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("mean_bound", [0.05, 0.1])
@pytest.mark.parametrize("cov_bound", [0.0005, 0.001])
def test_vlearn_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, mean_bound, cov_bound):
vlearn = VLEARN(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
gamma=gamma,
mean_bound=mean_bound,
cov_bound=cov_bound
)
assert vlearn.learning_rate == learning_rate
assert vlearn.n_steps == n_steps
assert vlearn.batch_size == batch_size
assert vlearn.gamma == gamma
assert vlearn.mean_bound == mean_bound
assert vlearn.cov_bound == cov_bound
def test_vlearn_predict(simple_env):
vlearn = VLEARN("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = vlearn.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_vlearn_learn():
vlearn = VLEARN("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = vlearn.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
vlearn.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = vlearn.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_vlearn_training(simple_env):
vlearn = VLEARN("CartPole-v1", total_timesteps=10000)
initial_performance = evaluate_policy(vlearn, simple_env)
vlearn.train()
final_performance = evaluate_policy(vlearn, simple_env)
assert final_performance > initial_performance, "VLearn should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes