Compare commits

...

8 Commits

16 changed files with 445 additions and 222 deletions

View File

@ -52,17 +52,34 @@ To run the test suite:
pytest test/test_ppo.py
```
## TODO
## Status
### Implemented Features
- Proximal Policy Optimization (PPO) algorithm
- Trust Region Policy Layers (TRPL) algorithm (WIP)
- Support for continuous and discrete action spaces
- Multiple projection methods (Rewritten for MIT License Compatability):
- KL Divergence projection
- Frobenius norm projection
- Wasserstein distance projection
- Identity projection (Eq to PPO)
- Configurable neural network architectures for actor and critic
- Logging support (Terminal and WandB, extendable)
### TODO
- [ ] All PPO Tests green
- [ ] Better / more logging
- [ ] Test / Benchmark PPO
- [ ] Refactor Modules for TRPL
- [ ] Get TRPL working
- [ ] Test / Benchmark TRPL
- [ ] All TRPL Tests green
- [ ] Make contextual covariance optional
- [ ] Allow full-cov via chol
- [ ] Test / Benchmark TRPL
- [ ] Write docs / extend README
- [ ] (Implement SAC?)
- [ ] Test func of non-gym envs
- [ ] Implement SAC
- [ ] Implement VLEARN
## Contributing

View File

@ -3,6 +3,7 @@ import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
from tensordict import TensorDict
from fancy_rl.loggers import TerminalLogger
@ -53,12 +54,11 @@ class Algo(ABC):
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment.")
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
return env
def train_step(self, batch):
@ -70,6 +70,20 @@ class Algo(ABC):
def evaluate(self, epoch):
raise NotImplementedError("evaluate method must be implemented in subclass.")
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()
def predict(
self,
observation,
state=None,
deterministic=False
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}, batch_size=[1])
action_td = self.prob_actor(td)
action = action_td["action"]
# We're not using recurrent policies, so we'll always return None for the state
next_state = None
return action.squeeze(0).cpu().numpy(), next_state

View File

@ -55,7 +55,7 @@ class OnPolicy(Algo):
# Create collector
self.collector = SyncDataCollector(
create_env_fn=lambda: self.make_env(eval=False),
policy=self.actor,
policy=self.prob_actor,
frames_per_batch=self.n_steps,
total_frames=self.total_timesteps,
device=self.device,

View File

@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection # Updated import
from fancy_rl.utils import is_discrete_space
class PPO(OnPolicy):
def __init__(
@ -31,7 +31,10 @@ class PPO(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
self.clip_range = clip_range
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
@ -41,15 +44,29 @@ class PPO(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
self.discrete = is_discrete_space(act_space)
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
if self.discrete:
distribution_class = torch.distributions.Categorical
distribution_kwargs = {"logits": "action_logits"}
else:
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
in_keys = ["loc", "scale_tril"]
else:
distribution_class = torch.distributions.Normal
in_keys = ["loc", "scale"]
self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=in_keys,
out_keys=["action"]
)
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
@ -86,7 +103,7 @@ class PPO(OnPolicy):
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=clip_range,
clip_epsilon=self.clip_range,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,

View File

@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
class ProjectedActor(TensorDictModule):
def __init__(self, raw_actor, old_actor, projection):
combined_module = self.CombinedModule(raw_actor, old_actor, projection)
super().__init__(
module=combined_module,
in_keys=raw_actor.in_keys,
out_keys=raw_actor.out_keys
)
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection
class CombinedModule(nn.Module):
def __init__(self, raw_actor, old_actor, projection):
super().__init__()
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection
def forward(self, tensordict):
raw_params = self.raw_actor(tensordict)
old_params = self.old_actor(tensordict)
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
projected_params = self.projection(combined_params)
return projected_params
class TRPL(OnPolicy):
def __init__(
@ -40,6 +69,7 @@ class TRPL(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
@ -50,8 +80,11 @@ class TRPL(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space
assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
# Handle projection_class
if isinstance(projection_class, str):
@ -60,20 +93,27 @@ class TRPL(OnPolicy):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
self.projection = projection_class(
in_keys=["loc", "scale"],
out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
mean_bound=trust_region_bound_mean,
cov_bound=trust_region_bound_cov
)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
else:
distribution_class = torch.distributions.Normal
distribution_kwargs = {"loc": "loc", "scale": "scale"}
self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=distribution_kwargs,
)
self.old_actor = deepcopy(self.actor)
self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(
@ -88,7 +128,7 @@ class TRPL(OnPolicy):
)
optimizers = {
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
"actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
}
@ -119,23 +159,7 @@ class TRPL(OnPolicy):
)
def update_old_policy(self):
self.old_actor.load_state_dict(self.actor.state_dict())
def project_policy(self, obs):
with torch.no_grad():
old_dist = self.old_actor(obs)
new_dist = self.actor(obs)
projected_params = self.projection.project(new_dist, old_dist)
return projected_params
def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)
# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict
self.old_actor.load_state_dict(self.raw_actor.state_dict())
def post_update(self):
self.update_old_policy()

View File

@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
def _trust_region_loss(self, tensordict):
old_distribution = self.old_actor_network(tensordict)
raw_distribution = self.actor_network(tensordict)
return self.projection(self.actor_network, raw_distribution, old_distribution)
new_distribution = self.actor_network(tensordict)
return self.projection.get_trust_region_loss(new_distribution, old_distribution)
def forward(self, tensordict: TensorDictBase) -> TensorDict:
tensordict = tensordict.clone(False)

View File

@ -3,31 +3,63 @@ from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict
class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
self.discrete = is_discrete_space(act_space)
act_space_shape = get_space_shape(act_space)
if is_discrete_space(act_space):
out_features = act_space_shape[-1]
else:
out_features = act_space_shape[-1] * 2
actor_module = nn.Sequential(
MLP(
in_features=get_space_shape(obs_space)[-1],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
device=device
),
NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
if self.discrete and full_covariance:
raise ValueError("Full covariance is not applicable for discrete action spaces.")
self.full_covariance = full_covariance
if self.discrete:
out_features = act_space_shape[-1]
out_keys = ["action_logits"]
else:
if full_covariance:
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
out_keys = ["loc", "scale_tril"]
else:
out_features = act_space_shape[-1] * 2
out_keys = ["loc", "scale"]
actor_module = MLP(
in_features=get_space_shape(obs_space)[-1],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
device=device
).to(device)
if not self.discrete:
if full_covariance:
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
else:
param_extractor = NormalParamExtractor()
actor_module = nn.Sequential(actor_module, param_extractor)
super().__init__(
module=actor_module,
in_keys=["observation"],
out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
out_keys=out_keys
)
class FullCovarianceNormalParamExtractor(nn.Module):
def __init__(self, action_dim):
super().__init__()
self.action_dim = action_dim
def forward(self, x):
loc = x[:, :self.action_dim]
scale_tril = torch.zeros(x.shape[0], self.action_dim, self.action_dim, device=x.device)
tril_indices = torch.tril_indices(row=self.action_dim, col=self.action_dim, offset=0)
scale_tril[:, tril_indices[0], tril_indices[1]] = x[:, self.action_dim:]
scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
return TensorDict({"loc": loc, "scale_tril": scale_tril}, batch_size=x.shape[0])
class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP(

View File

@ -1,16 +1,71 @@
from abc import ABC, abstractmethod
import torch
from typing import Dict
from torch import nn
from typing import Dict, List
class BaseProjection(ABC, torch.nn.Module):
def __init__(self, in_keys: list[str], out_keys: list[str]):
class BaseProjection(nn.Module, ABC):
def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__()
self._validate_in_keys(in_keys)
self._validate_out_keys(out_keys)
self.in_keys = in_keys
self.out_keys = out_keys
self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound
self.cov_bound = cov_bound
self.full_cov = "scale_tril" in in_keys
self.contextual_std = contextual_std
def _validate_in_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys or "old_loc" not in keys:
raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
def _validate_out_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys:
raise ValueError("'loc' must be included in out_keys")
if "scale" not in keys and "scale_tril" not in keys:
raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")
@abstractmethod
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pass
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.project(policy_params, old_policy_params)
@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
pass
def forward(self, tensordict):
policy_params = {}
old_policy_params = {}
for key in self.in_keys:
if key not in tensordict:
raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
if key.startswith("old_"):
old_policy_params[key[4:]] = tensordict[key]
else:
policy_params[key] = tensordict[key]
projected_params = self.project(policy_params, old_policy_params)
return projected_params
def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
if not self.full_cov:
return torch.diag_embed(params["scale"].pow(2))
else:
return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
if not self.full_cov:
return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
else:
return torch.linalg.cholesky(cov)

View File

@ -1,33 +1,34 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict
class FrobeniusProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params)
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)
proj_chol = torch.linalg.cholesky(proj_cov)
return {"loc": proj_mean, "scale_tril": proj_chol}
scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, chol = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
cov = torch.matmul(chol, chol.transpose(-1, -2))
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
cov = self._calc_covariance(policy_params)
proj_cov = self._calc_covariance(proj_policy_params)
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))

View File

@ -3,8 +3,8 @@ from .base_projection import BaseProjection
from typing import Dict
class IdentityProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return policy_params

View File

@ -2,6 +2,7 @@ import torch
import cpp_projection
import numpy as np
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple, Any
MAX_EVAL = 1000
@ -10,57 +11,65 @@ def get_numpy(tensor):
return tensor.detach().cpu().numpy()
class KLProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.is_diag = is_diag
self.contextual_std = contextual_std
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, std = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
if not self.contextual_std:
std = std[:1]
old_std = old_std[:1]
scale_or_tril = scale_or_tril[:1]
old_scale_or_tril = old_scale_or_tril[:1]
cov_part = cov_part[:1]
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_std = self._cov_projection(std, old_std, cov_part)
proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
if not self.contextual_std:
proj_std = proj_std.expand(mean.shape[0], -1, -1)
proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
return {"loc": proj_mean, "scale_tril": proj_std}
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, std = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
return kl.mean() * self.trust_region_coeff
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
mean, std = p
mean_other, std_other = q
mean, scale_or_tril = p
mean_other, scale_or_tril_other = q
k = mean.shape[-1]
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
det_term = self._log_determinant(std)
det_term_other = self._log_determinant(std_other)
det_term = self._log_determinant(scale_or_tril)
det_term_other = self._log_determinant(scale_or_tril_other)
if self.full_cov:
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
else:
trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
return maha_part, cov_part
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
diff = x - y
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
if self.full_cov:
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
else:
return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
if self.full_cov:
return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
else:
return 2 * torch.log(scale_or_tril).sum(-1)
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
return torch.sum(x.pow(2), dim=(-2, -1))
@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
cov = torch.matmul(std, std.transpose(-1, -2))
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
if self.is_diag:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
try:
if mask.any():
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
old_cov.diagonal(dim1=-2, dim2=-1),
self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
except Exception as e:
proj_std = old_std
def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if self.full_cov:
cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
else:
try:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
if mask.any():
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
cov = scale_or_tril.pow(2)
old_cov = old_scale_or_tril.pow(2)
mask = cov_part > self.cov_bound
proj_scale_or_tril = torch.zeros_like(scale_or_tril)
proj_scale_or_tril[~mask] = scale_or_tril[~mask]
try:
if mask.any():
if self.full_cov:
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
mask &= ~is_invalid
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_std[failed_mask] = old_std[failed_mask]
except Exception as e:
import logging
logging.error('Projection failed, taking old cholesky for projection.')
print("Projection failed, taking old cholesky for projection.")
proj_std = old_std
raise e
proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
else:
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
mask &= ~is_invalid
proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
except Exception as e:
import logging
logging.error('Projection failed, taking old scale_or_tril for projection.')
print("Projection failed, taking old scale_or_tril for projection.")
proj_scale_or_tril = old_scale_or_tril
raise e
return proj_std
return proj_scale_or_tril
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):

View File

@ -1,56 +1,86 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple
def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
"""
'Converts' scale_tril to scale_sqrt.
For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
"""
return scale_tril
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
mean, sqrt = p
mean_other, sqrt_other = q
mean, scale_or_sqrt = p
mean_other, scale_or_sqrt_other = q
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
if scale_or_sqrt.dim() == mean.dim(): # Diagonal case
cov = scale_or_sqrt.pow(2)
cov_other = scale_or_sqrt_other.pow(2)
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = 1 / scale_or_sqrt_other
c = sqrt_inv_other.pow(2) * cov
cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
else:
cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
else: # Full covariance case
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
return mean_part, cov_part
class WassersteinProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
self.scale_prec = scale_prec
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
return {"loc": proj_mean, "scale_tril": proj_sqrt}
return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.norm(diff, dim=-1, keepdim=True)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
norm = torch.sqrt(mean_part)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
diff = sqrt - old_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2: # Diagonal case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = torch.sqrt(cov_part)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
else: # Full covariance case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)

View File

@ -33,24 +33,28 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}")
def get_space_shape(action_space):
if gym_available:
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,)
else:
discrete_types = ()
continuous_types = ()
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
if isinstance(action_space, discrete_types):
if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
return (action_space.n,)
elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
return (sum(action_space.nvec),)
elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif gym_available:
if isinstance(action_space, GymDiscrete):
return (action_space.n,)
elif isinstance(action_space, GymMultiDiscrete):
return (sum(action_space.nvec),)
elif isinstance(action_space, GymMultiBinary):
return (action_space.n,)
elif isinstance(action_space, continuous_types):
return action_space.shape

View File

@ -3,12 +3,15 @@ import numpy as np
from fancy_rl import PPO
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
return gym.make('LunarLander-v2', continuous=True)
def test_ppo_instantiation():
ppo = PPO("CartPole-v1")
ppo = PPO(simple_env)
assert isinstance(ppo, PPO)
def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1')
assert isinstance(ppo, PPO)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@ -19,7 +22,7 @@ def test_ppo_instantiation():
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
"CartPole-v1",
simple_env,
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range
def test_ppo_predict(simple_env):
ppo = PPO("CartPole-v1")
obs, _ = simple_env.reset()
def test_ppo_predict():
ppo = PPO(simple_env)
env = ppo.make_env()
obs, _ = env.reset()
action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
assert action.shape == env.action_space.shape
def test_ppo_learn():
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = ppo.make_env()
obs, _ = env.reset()
for _ in range(64):
action, _ = ppo.predict(obs)
next_obs, reward, done, truncated, _ = env.step(action)
ppo.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
obs, reward, done, truncated, _ = env.step(action)
if done or truncated:
obs, _ = env.reset()
loss = ppo.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000)
env = ppo.make_env()
initial_performance = evaluate_policy(ppo, env)
ppo.train()
final_performance = evaluate_policy(ppo, env)
assert final_performance > initial_performance, "PPO should improve performance after training"
def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes

View File

@ -3,12 +3,15 @@ import numpy as np
from fancy_rl import TRPL
import gymnasium as gym
@pytest.fixture
def simple_env():
return gym.make('CartPole-v1')
return gym.make('LunarLander-v2', continuous=True)
def test_trpl_instantiation():
trpl = TRPL("CartPole-v1")
trpl = TRPL(simple_env)
assert isinstance(trpl, TRPL)
def test_trpl_instantiation_from_str():
trpl = TRPL('MountainCarContinuous-v0')
assert isinstance(trpl, TRPL)
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
@ -19,7 +22,7 @@ def test_trpl_instantiation():
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
trpl = TRPL(
"CartPole-v1",
simple_env,
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,
@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
def test_trpl_predict(simple_env):
trpl = TRPL("CartPole-v1")
obs, _ = simple_env.reset()
def test_trpl_predict():
trpl = TRPL(simple_env)
env = trpl.make_env()
obs, _ = env.reset()
action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == simple_env.action_space.shape
assert action.shape == env.action_space.shape
def test_trpl_learn():
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
trpl = TRPL(simple_env, n_steps=64, batch_size=32)
env = trpl.make_env()
obs, _ = env.reset()
for _ in range(64):
action, _ = trpl.predict(obs)
@ -58,12 +62,13 @@ def test_trpl_learn():
assert "policy_loss" in loss
assert "value_loss" in loss
def test_trpl_training(simple_env):
trpl = TRPL("CartPole-v1", total_timesteps=10000)
def test_trpl_training():
trpl = TRPL(simple_env, total_timesteps=10000)
env = trpl.make_env()
initial_performance = evaluate_policy(trpl, simple_env)
initial_performance = evaluate_policy(trpl, env)
trpl.train()
final_performance = evaluate_policy(trpl, simple_env)
final_performance = evaluate_policy(trpl, env)
assert final_performance > initial_performance, "TRPL should improve performance after training"