Compare commits


No commits in common. "9c55b6a110a179caa594129e9f963f0b4e622be7" and "4240f611ac459670f9a89fb9a750d7cb04fe9aa2" have entirely different histories.

23 changed files with 123 additions and 916 deletions

View File

@@ -4,7 +4,6 @@ try:
except ImportError:
pass
from fancy_rl.algos import PPO, TRPL, VLEARN
from fancy_rl.projections import get_projection
from fancy_rl.algos import PPO
__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]
__all__ = ["PPO"]

View File

@@ -1,3 +1 @@
from fancy_rl.algos.ppo import PPO
from fancy_rl.algos.trpl import TRPL
from fancy_rl.algos.vlearn import VLEARN
from fancy_rl.algos.ppo import PPO

View File

@@ -15,17 +15,23 @@ class OnPolicy(Algo):
env_spec,
optimizers,
loggers=None,
actor_hidden_sizes=[64, 64],
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
total_timesteps=1e6,
eval_interval=2048,
eval_deterministic=True,
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
clip_range=0.2,
env_spec_eval=None,
eval_episodes=10,
device=None,
@@ -71,25 +77,15 @@ class OnPolicy(Algo):
batch_size=self.batch_size,
)
def pre_process_batch(self, batch):
return batch
def post_process_batch(self, batch):
pass
def train_step(self, batch):
batch = self.pre_process_batch(batch)
for optimizer in self.optimizers.values():
optimizer.zero_grad()
losses = self.loss_module(batch)
loss = sum(losses.values()) # Sum all losses
loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
loss.backward()
for optimizer in self.optimizers.values():
optimizer.step()
self.post_process_batch(batch)
return loss
def train(self):

View File

@@ -4,7 +4,6 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection # Updated import
class PPO(OnPolicy):
def __init__(

View File

@@ -1,16 +1,9 @@
import torch
from torch import nn
from typing import Dict, Any, Optional
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
from torchrl.objectives.value import GAE
from torchrl.modules import ProbabilisticActor
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
class TRPL(OnPolicy):
def __init__(
@@ -21,21 +14,19 @@ class TRPL(OnPolicy):
critic_hidden_sizes=[64, 64],
actor_activation_fn="Tanh",
critic_activation_fn="Tanh",
proj_layer_type=None,
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
projection_class="identity_projection",
trust_region_coef=10.0,
trust_region_bound_mean=0.1,
trust_region_bound_cov=0.001,
total_timesteps=1e6,
eval_interval=2048,
eval_deterministic=True,
entropy_coef=0.01,
critic_coef=0.5,
trust_region_coef=10.0,
normalize_advantage=False,
device=None,
env_spec_eval=None,
@@ -44,6 +35,9 @@ class TRPL(OnPolicy):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
self.trust_region_layer = None # TODO: from proj_layer_type
self.trust_region_coef = trust_region_coef
# Initialize environment to get observation and action space sizes
self.env_spec = env_spec
env = self.make_env()
@@ -52,40 +46,14 @@ class TRPL(OnPolicy):
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
# Handle projection_class
if isinstance(projection_class, str):
projection_class = get_projection(projection_class)
elif not issubclass(projection_class, BaseProjection):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
self.projection = projection_class(
in_keys=["loc", "scale"],
out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
)
self.actor = ProbabilisticActor(
raw_actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)
self.old_actor = deepcopy(self.actor)
self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(
actor_network=self.actor,
old_actor_network=self.old_actor,
critic_network=self.critic,
projection=self.projection,
entropy_coef=entropy_coef,
critic_coef=critic_coef,
trust_region_coef=trust_region_coef,
normalize_advantage=normalize_advantage,
)
self.actor = raw_actor # TODO: Proj here
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@@ -111,6 +79,7 @@ class TRPL(OnPolicy):
env_spec_eval=env_spec_eval,
eval_episodes=eval_episodes,
)
self.adv_module = GAE(
gamma=self.gamma,
lmbda=gae_lambda,
@@ -118,24 +87,13 @@ class TRPL(OnPolicy):
average_gae=False,
)
def update_old_policy(self):
self.old_actor.load_state_dict(self.actor.state_dict())
def project_policy(self, obs):
with torch.no_grad():
old_dist = self.old_actor(obs)
new_dist = self.actor(obs)
projected_params = self.projection.project(new_dist, old_dist)
return projected_params
def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)
# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict
def post_update(self):
self.update_old_policy()
self.loss_module = TRPLLoss(
actor_network=self.actor,
critic_network=self.critic,
trust_region_layer=self.trust_region_layer,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
trust_region_coef=self.trust_region_coef,
normalize_advantage=self.normalize_advantage,
)

View File

@@ -1,114 +0,0 @@
import torch
from torch import nn
from typing import Dict, Any, Optional
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from fancy_rl.utils import get_env, get_actor, get_critic
from fancy_rl.modules.vlearn_loss import VLEARNLoss
from fancy_rl.modules.projection import get_vlearn_projection
from fancy_rl.modules.squashed_normal import get_squashed_normal
class VLEARN:
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):
self.device = torch.device(device)
self.env = get_env(env_id)
self.projection = get_vlearn_projection(**kwargs.get("projection", {}))
actor = get_actor(self.env, **kwargs.get("actor", {}))
self.actor = ProbabilisticActor(
actor,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=get_squashed_normal(),
return_log_prob=True
).to(self.device)
self.old_actor = self.actor.clone()
self.critic = ValueOperator(
module=get_critic(self.env, **kwargs.get("critic", {})),
in_keys=["observation"]
).to(self.device)
self.collector = SyncDataCollector(
self.env,
self.actor,
frames_per_batch=kwargs.get("frames_per_batch", 1000),
total_frames=kwargs.get("total_frames", -1),
device=self.device,
)
self.replay_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(kwargs.get("buffer_size", 100000)),
batch_size=kwargs.get("batch_size", 256),
)
self.loss_module = VLEARNLoss(
actor_network=self.actor,
critic_network=self.critic,
old_actor_network=self.old_actor,
projection=self.projection,
**kwargs.get("loss", {})
)
self.optimizers = nn.ModuleDict({
"policy": torch.optim.Adam(self.actor.parameters(), lr=kwargs.get("lr_policy", 3e-4)),
"critic": torch.optim.Adam(self.critic.parameters(), lr=kwargs.get("lr_critic", 3e-4))
})
self.update_policy_interval = kwargs.get("update_policy_interval", 1)
self.update_critic_interval = kwargs.get("update_critic_interval", 1)
self.target_update_interval = kwargs.get("target_update_interval", 1)
self.polyak_weight_critic = kwargs.get("polyak_weight_critic", 0.995)
def train(self, num_iterations: int = 1000) -> None:
for i in range(num_iterations):
data = next(self.collector)
self.replay_buffer.extend(data)
batch = self.replay_buffer.sample().to(self.device)
loss_dict = self.loss_module(batch)
if i % self.update_policy_interval == 0:
self.optimizers["policy"].zero_grad()
loss_dict["policy_loss"].backward()
self.optimizers["policy"].step()
if i % self.update_critic_interval == 0:
self.optimizers["critic"].zero_grad()
loss_dict["critic_loss"].backward()
self.optimizers["critic"].step()
if i % self.target_update_interval == 0:
self.critic.update_target_params(self.polyak_weight_critic)
self.old_actor.load_state_dict(self.actor.state_dict())
self.collector.update_policy_weights_()
if i % 100 == 0:
eval_reward = self.eval()
print(f"Iteration {i}, Eval reward: {eval_reward}")
def eval(self, num_episodes: int = 10) -> float:
total_reward = 0
for _ in range(num_episodes):
td = self.env.reset()
done = False
while not done:
with torch.no_grad():
action = self.actor(td.to(self.device))["action"]
td = self.env.step(action)
total_reward += td["reward"].item()
done = td["done"].item()
return total_reward / num_episodes
def save_policy(self, path: str) -> None:
torch.save(self.actor.state_dict(), f"{path}/actor.pth")
torch.save(self.critic.state_dict(), f"{path}/critic.pth")
def load_policy(self, path: str) -> None:
self.actor.load_state_dict(torch.load(f"{path}/actor.pth"))
self.critic.load_state_dict(torch.load(f"{path}/critic.pth"))
self.old_actor.load_state_dict(self.actor.state_dict())
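For reference, a minimal usage sketch of the VLEARN interface removed above; the environment id, keyword arguments, and checkpoint path are illustrative assumptions (get_env must be able to resolve the id):

from fancy_rl.algos.vlearn import VLEARN  # module removed in this diff

agent = VLEARN(
    "Pendulum-v1",                  # illustrative env id (assumption)
    device="cpu",
    frames_per_batch=1000,
    batch_size=256,
    lr_policy=3e-4,
    lr_critic=3e-4,
)
agent.train(num_iterations=200)     # collect data, update policy/critic, sync old actor
print(agent.eval(num_episodes=5))   # mean episodic return over 5 evaluation episodes
agent.save_policy("./checkpoints")  # writes actor.pth and critic.pth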

View File

@@ -38,40 +38,100 @@ from torchrl.objectives.value import (
)
from torchrl.objectives.ppo import PPOLoss
from fancy_rl.projections import get_projection
class TRPLLoss(PPOLoss):
@dataclass
class _AcceptedKeys:
"""Maintains default values for all configurable tensordict keys.
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
default values.
Attributes:
advantage (NestedKey): The input tensordict key where the advantage is expected.
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
value_target (NestedKey): The input tensordict key where the target state value is expected.
Will be used for the underlying value estimator. Defaults to ``"value_target"``.
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
sample_log_prob (NestedKey): The input tensordict key where the
sample log probability is expected. Defaults to ``"sample_log_prob"``.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Will be used for the underlying value estimator. Defaults to ``"reward"``.
done (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is done. Will be used for the underlying value estimator.
Defaults to ``"done"``.
terminated (NestedKey): The key in the input TensorDict that indicates
whether a trajectory is terminated. Will be used for the underlying value estimator.
Defaults to ``"terminated"``.
"""
advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.GAE
actor_network: TensorDictModule
critic_network: TensorDictModule
actor_network_params: TensorDictParams
critic_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_critic_network_params: TensorDictParams
def __init__(
self,
actor_network: ProbabilisticTensorDictSequential,
old_actor_network: ProbabilisticTensorDictSequential,
critic_network: TensorDictModule,
projection: any,
actor_network: ProbabilisticTensorDictSequential | None = None,
critic_network: TensorDictModule | None = None,
trust_region_layer: any | None = None,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coef: float = 0.01,
critic_coef: float = 1.0,
trust_region_coef: float = 10.0,
loss_critic_type: str = "smooth_l1",
normalize_advantage: bool = False,
gamma: float = None,
separate_losses: bool = False,
reduction: str = None,
clip_value: bool | float | None = None,
**kwargs,
):
super().__init__(
actor_network=actor_network,
critic_network=critic_network,
self.trust_region_layer = trust_region_layer
self.trust_region_coef = trust_region_coef
super(TRPLLoss, self).__init__(
actor_network,
critic_network,
entropy_bonus=entropy_bonus,
samples_mc_entropy=samples_mc_entropy,
entropy_coef=entropy_coef,
critic_coef=critic_coef,
loss_critic_type=loss_critic_type,
normalize_advantage=normalize_advantage,
gamma=gamma,
separate_losses=separate_losses,
reduction=reduction,
clip_value=clip_value,
**kwargs,
)
self.old_actor_network = old_actor_network
self.projection = projection
self.trust_region_coef = trust_region_coef
@property
def out_keys(self):
if self._out_keys is None:
keys = ["loss_objective", "tr_loss"]
keys = ["loss_objective"]
if self.entropy_bonus:
keys.extend(["entropy", "loss_entropy"])
if self.critic_coef:
if self.loss_critic:
keys.append("loss_critic")
keys.append("ESS")
self._out_keys = keys
@@ -81,12 +141,8 @@ class TRPLLoss(PPOLoss):
def out_keys(self, values):
self._out_keys = values
def _trust_region_loss(self, tensordict):
old_distribution = self.old_actor_network(tensordict)
raw_distribution = self.actor_network(tensordict)
return self.projection(self.actor_network, raw_distribution, old_distribution)
def forward(self, tensordict: TensorDictBase) -> TensorDict:
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
@@ -103,8 +159,11 @@ class TRPLLoss(PPOLoss):
log_weight, dist, kl_approx = self._log_weight(tensordict)
trust_region_loss_unscaled = self._trust_region_loss(tensordict)
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
to different, unrelated trajectories, which is not standard. Still it can give an idea of the dispersion
# of the weights.
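# Equivalently, ESS = (sum_i w_i)^2 / (sum_i w_i^2) with w_i = exp(lw_i); the log-space
# form below avoids overflow when exponentiating the importance weights.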
lw = log_weight.squeeze()
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
batch = log_weight.shape[0]
@@ -135,3 +194,8 @@ class TRPLLoss(PPOLoss):
batch_size=[],
)
return td_out
def _trust_region_loss(self, tensordict):
old_distribution =
raw_distribution =
return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
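A minimal sketch, not part of the diff, of how the tensordict keys documented in the _AcceptedKeys docstring above can be remapped with torchrl's LossModule.set_keys; the actor, critic, projection layer, batch, and replacement key names are assumptions for illustration:

from fancy_rl.objectives import TRPLLoss

loss_module = TRPLLoss(
    actor_network=actor,            # assumed ProbabilisticTensorDictSequential policy
    critic_network=critic,          # assumed TensorDictModule writing "state_value"
    trust_region_layer=projection,  # assumed projection layer instance
    trust_region_coef=10.0,
)
# Remap the keys the loss reads from the sampled batch (illustrative names).
loss_module.set_keys(advantage="adv", sample_log_prob="log_prob")
losses = loss_module(batch)  # TensorDict with loss_objective, loss_critic, loss_entropy, ESS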

View File

@@ -1,106 +0,0 @@
import torch
from torchrl.objectives import LossModule
from torch.distributions import Normal
class VLEARNLoss(LossModule):
def __init__(
self,
actor_network,
critic_network,
old_actor_network,
gamma=0.99,
lmbda=0.95,
entropy_coef=0.01,
critic_coef=0.5,
normalize_advantage=True,
eps=1e-8,
delta=0.1
):
super().__init__()
self.actor_network = actor_network
self.critic_network = critic_network
self.old_actor_network = old_actor_network
self.gamma = gamma
self.lmbda = lmbda
self.entropy_coef = entropy_coef
self.critic_coef = critic_coef
self.normalize_advantage = normalize_advantage
self.eps = eps
self.delta = delta
def forward(self, tensordict):
# Compute returns and advantages
with torch.no_grad():
returns = self.compute_returns(tensordict)
values = self.critic_network(tensordict)["state_value"]
advantages = returns - values
if self.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Compute actor loss
new_td = self.actor_network(tensordict)
old_td = self.old_actor_network(tensordict)
new_dist = Normal(new_td["loc"], new_td["scale"])
old_dist = Normal(old_td["loc"], old_td["scale"])
new_log_prob = new_dist.log_prob(tensordict["action"]).sum(-1)
old_log_prob = old_dist.log_prob(tensordict["action"]).sum(-1)
ratio = torch.exp(new_log_prob - old_log_prob)
# Compute projection
kl = torch.distributions.kl.kl_divergence(new_dist, old_dist).sum(-1)
alpha = torch.where(kl > self.delta,
torch.sqrt(self.delta / (kl + self.eps)),
torch.ones_like(kl))
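# alpha stays 1 while the KL is within delta; otherwise alpha < 1 interpolates loc and
# scale toward the old distribution so the projected policy approximately respects the bound.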
proj_loc = alpha.unsqueeze(-1) * new_td["loc"] + (1 - alpha.unsqueeze(-1)) * old_td["loc"]
proj_scale = torch.sqrt(alpha.unsqueeze(-1)**2 * new_td["scale"]**2 + (1 - alpha.unsqueeze(-1))**2 * old_td["scale"]**2)
proj_dist = Normal(proj_loc, proj_scale)
proj_log_prob = proj_dist.log_prob(tensordict["action"]).sum(-1)
proj_ratio = torch.exp(proj_log_prob - old_log_prob)
policy_loss = -torch.min(
ratio * advantages,
proj_ratio * advantages
).mean()
# Compute critic loss
value_pred = self.critic_network(tensordict)["state_value"]
critic_loss = 0.5 * (returns - value_pred).pow(2).mean()
# Compute entropy loss
entropy_loss = -self.entropy_coef * new_dist.entropy().mean()
# Combine losses
loss = policy_loss + self.critic_coef * critic_loss + entropy_loss
return {
"loss": loss,
"policy_loss": policy_loss,
"critic_loss": critic_loss,
"entropy_loss": entropy_loss,
}
def compute_returns(self, tensordict):
rewards = tensordict["reward"]
dones = tensordict["done"]
values = self.critic_network(tensordict)["state_value"]
returns = torch.zeros_like(rewards)
advantages = torch.zeros_like(rewards)
last_gae_lam = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = 0
else:
next_value = values[t + 1]
delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
advantages[t] = last_gae_lam = delta + self.gamma * self.lmbda * (1 - dones[t]) * last_gae_lam
returns = advantages + values
return returns

View File

@@ -1,20 +1,6 @@
from .base_projection import BaseProjection
from .identity_projection import IdentityProjection
from .kl_projection import KLProjection
from .wasserstein_projection import WassersteinProjection
from .frobenius_projection import FrobeniusProjection
def get_projection(projection_name: str):
projections = {
"identity_projection": IdentityProjection,
"kl_projection": KLProjection,
"wasserstein_projection": WassersteinProjection,
"frobenius_projection": FrobeniusProjection,
}
projection = projections.get(projection_name.lower())
if projection is None:
raise ValueError(f"Unknown projection: {projection_name}")
return projection
__all__ = ["BaseProjection", "IdentityProjection", "KLProjection", "WassersteinProjection", "FrobeniusProjection", "get_projection"]
try:
import cpp_projection
except ModuleNotFoundError:
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
else:
from .kl_projection_layer import KLProjectionLayer

View File

@@ -1,16 +0,0 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict
class BaseProjection(ABC, torch.nn.Module):
def __init__(self, in_keys: list[str], out_keys: list[str]):
super().__init__()
self.in_keys = in_keys
self.out_keys = out_keys
@abstractmethod
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pass
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.project(policy_params, old_policy_params)

View File

@@ -1,67 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict
class FrobeniusProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
proj_chol = torch.linalg.cholesky(proj_cov)
return {"loc": proj_mean, "scale_tril": proj_chol}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
def _gaussian_frobenius(self, p, q):
mean, cov = p
old_mean, old_cov = q
if self.scale_prec:
prec_old = torch.inverse(old_cov)
mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
else:
mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
return mean_part, cov_part
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.sqrt(mean_part)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
batch_shape = cov.shape[:-2]
cov_mask = cov_part > self.cov_bound
eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
eta = torch.max(-eta, eta)
new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
return proj_cov

View File

@@ -1,13 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict
class IdentityProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return policy_params
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
return torch.tensor(0.0, device=next(iter(policy_params.values())).device)

View File

@@ -1,199 +0,0 @@
import torch
import cpp_projection
import numpy as np
from .base_projection import BaseProjection
from typing import Dict, Tuple, Any
MAX_EVAL = 1000
def get_numpy(tensor):
return tensor.detach().cpu().numpy()
class KLProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.is_diag = is_diag
self.contextual_std = contextual_std
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, std = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
if not self.contextual_std:
std = std[:1]
old_std = old_std[:1]
cov_part = cov_part[:1]
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_std = self._cov_projection(std, old_std, cov_part)
if not self.contextual_std:
proj_std = proj_std.expand(mean.shape[0], -1, -1)
return {"loc": proj_mean, "scale_tril": proj_std}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, std = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
return kl.mean() * self.trust_region_coeff
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
mean, std = p
mean_other, std_other = q
k = mean.shape[-1]
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
det_term = self._log_determinant(std)
det_term_other = self._log_determinant(std_other)
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
return maha_part, cov_part
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
diff = x - y
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
return torch.sum(x.pow(2), dim=(-2, -1))
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
cov = torch.matmul(std, std.transpose(-1, -2))
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
if self.is_diag:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
try:
if mask.any():
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
old_cov.diagonal(dim1=-2, dim2=-1),
self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
except Exception as e:
proj_std = old_std
else:
try:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
if mask.any():
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_std[failed_mask] = old_std[failed_mask]
except Exception as e:
import logging
logging.error('Projection failed, taking old cholesky for projection.')
print("Projection failed, taking old cholesky for projection.")
proj_std = old_std
raise e
return proj_std
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
projection_op = None
@staticmethod
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
if not KLProjectionGradFunctionCovOnly.projection_op:
KLProjectionGradFunctionCovOnly.projection_op = \
cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
return KLProjectionGradFunctionCovOnly.projection_op
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
cov, chol, old_chol, eps_cov = args
batch_shape = cov.shape[0]
dim = cov.shape[-1]
cov_np = get_numpy(cov)
chol_np = get_numpy(chol)
old_chol_np = get_numpy(old_chol)
eps = get_numpy(eps_cov) * np.ones(batch_shape)
p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim)
ctx.proj = p_op
proj_std = p_op.forward(eps, old_chol_np, chol_np, cov_np)
return cov.new(proj_std)
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
projection_op = ctx.proj
d_cov, = grad_outputs
d_cov_np = get_numpy(d_cov)
d_cov_np = np.atleast_2d(d_cov_np)
df_stds = projection_op.backward(d_cov_np)
df_stds = np.atleast_2d(df_stds)
df_stds = d_cov.new(df_stds)
return df_stds, None, None, None
class KLProjectionGradFunctionDiagCovOnly(torch.autograd.Function):
projection_op = None
@staticmethod
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
if not KLProjectionGradFunctionDiagCovOnly.projection_op:
KLProjectionGradFunctionDiagCovOnly.projection_op = \
cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
return KLProjectionGradFunctionDiagCovOnly.projection_op
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
cov, old_cov, eps_cov = args
batch_shape = cov.shape[0]
dim = cov.shape[-1]
cov_np = get_numpy(cov)
old_cov_np = get_numpy(old_cov)
eps = get_numpy(eps_cov) * np.ones(batch_shape)
p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
ctx.proj = p_op
proj_std = p_op.forward(eps, old_cov_np, cov_np)
return cov.new(proj_std)
@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
projection_op = ctx.proj
d_std, = grad_outputs
d_cov_np = get_numpy(d_std)
d_cov_np = np.atleast_2d(d_cov_np)
df_stds = projection_op.backward(d_cov_np)
df_stds = np.atleast_2d(df_stds)
return d_std.new(df_stds), None, None

View File

@@ -1,56 +0,0 @@
import torch
from .base_projection import BaseProjection
from typing import Dict, Tuple
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
mean, sqrt = p
mean_other, sqrt_other = q
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
return {"loc": proj_mean, "scale_tril": proj_sqrt}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.norm(diff, dim=-1, keepdim=True)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
diff = sqrt - old_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)

View File

@@ -1,6 +0,0 @@
try:
import cpp_projection
except ModuleNotFoundError:
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
else:
from .kl_projection_layer import KLProjectionLayer

View File

@@ -1,59 +1 @@
import pytest
import numpy as np
from fancy_rl import PPO
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_ppo_instantiation():
ppo = PPO("CartPole-v1")
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("n_epochs", [5, 10])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
gamma=gamma,
clip_range=clip_range
)
assert ppo.learning_rate == learning_rate
assert ppo.n_steps == n_steps
assert ppo.batch_size == batch_size
assert ppo.n_epochs == n_epochs
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range
def test_ppo_predict(simple_env):
ppo = PPO("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_ppo_learn():
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = ppo.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
ppo.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = ppo.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
# TODO

View File

@@ -1,77 +0,0 @@
import pytest
import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_trpl_instantiation():
trpl = TRPL("CartPole-v1")
assert isinstance(trpl, TRPL)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
trpl = TRPL(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
gamma=gamma,
max_kl=max_kl
)
assert trpl.learning_rate == learning_rate
assert trpl.n_steps == n_steps
assert trpl.batch_size == batch_size
assert trpl.gamma == gamma
assert trpl.max_kl == max_kl
def test_trpl_predict(simple_env):
trpl = TRPL("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_trpl_learn():
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = trpl.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
trpl.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = trpl.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_trpl_training(simple_env):
trpl = TRPL("CartPole-v1", total_timesteps=10000)
initial_performance = evaluate_policy(trpl, simple_env)
trpl.train()
final_performance = evaluate_policy(trpl, simple_env)
assert final_performance > initial_performance, "TRPL should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes

View File

@@ -1,81 +0,0 @@
import pytest
import torch
import numpy as np
from fancy_rl import VLEARN
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
def test_vlearn_instantiation():
vlearn = VLEARN("CartPole-v1")
assert isinstance(vlearn, VLEARN)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@pytest.mark.parametrize("n_steps", [1024, 2048])
@pytest.mark.parametrize("batch_size", [32, 64, 128])
@pytest.mark.parametrize("gamma", [0.95, 0.99])
@pytest.mark.parametrize("mean_bound", [0.05, 0.1])
@pytest.mark.parametrize("cov_bound", [0.0005, 0.001])
def test_vlearn_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, mean_bound, cov_bound):
vlearn = VLEARN(
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
gamma=gamma,
mean_bound=mean_bound,
cov_bound=cov_bound
)
assert vlearn.learning_rate == learning_rate
assert vlearn.n_steps == n_steps
assert vlearn.batch_size == batch_size
assert vlearn.gamma == gamma
assert vlearn.mean_bound == mean_bound
assert vlearn.cov_bound == cov_bound
def test_vlearn_predict(simple_env):
vlearn = VLEARN("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = vlearn.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
def test_vlearn_learn():
vlearn = VLEARN("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = vlearn.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
vlearn.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()
loss = vlearn.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_vlearn_training(simple_env):
vlearn = VLEARN("CartPole-v1", total_timesteps=10000)
initial_performance = evaluate_policy(vlearn, simple_env)
vlearn.train()
final_performance = evaluate_policy(vlearn, simple_env)
assert final_performance > initial_performance, "VLearn should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes