diff --git a/fancy_rl/objectives/__init__.py b/fancy_rl/objectives/__init__.py
new file mode 100644
index 0000000..b92ea19
--- /dev/null
+++ b/fancy_rl/objectives/__init__.py
@@ -0,0 +1 @@
+from fancy_rl.objectives.trpl import TRPLLoss
\ No newline at end of file
diff --git a/fancy_rl/objectives/trpl.py b/fancy_rl/objectives/trpl.py
new file mode 100644
index 0000000..10abc53
--- /dev/null
+++ b/fancy_rl/objectives/trpl.py
@@ -0,0 +1,201 @@
+from __future__ import annotations
+
+import contextlib
+
+import math
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any, Tuple
+
+import torch
+from tensordict import TensorDict, TensorDictBase, TensorDictParams
+from tensordict.nn import (
+    dispatch,
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+    TensorDictModule,
+)
+from tensordict.utils import NestedKey
+from torch import distributions as d
+
+from torchrl.objectives.common import LossModule
+
+from torchrl.objectives.utils import (
+    _cache_values,
+    _clip_value_loss,
+    _GAMMA_LMBDA_DEPREC_ERROR,
+    _reduce,
+    default_value_kwargs,
+    distance_loss,
+    ValueEstimators,
+)
+from torchrl.objectives.value import (
+    GAE,
+    TD0Estimator,
+    TD1Estimator,
+    TDLambdaEstimator,
+    VTrace,
+)
+
+from torchrl.objectives.ppo import PPOLoss
+
+
+class TRPLLoss(PPOLoss):
+    @dataclass
+    class _AcceptedKeys:
+        """Maintains default values for all configurable tensordict keys.
+
+        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+
+        Attributes:
+            advantage (NestedKey): The input tensordict key where the advantage is expected.
+                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
+            value_target (NestedKey): The input tensordict key where the target state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"value_target"``.
+            value (NestedKey): The input tensordict key where the state value is expected.
+                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
+            sample_log_prob (NestedKey): The input tensordict key where the
+                sample log probability is expected. Defaults to ``"sample_log_prob"``.
+            action (NestedKey): The input tensordict key where the action is expected.
+                Defaults to ``"action"``.
+            reward (NestedKey): The input tensordict key where the reward is expected.
+                Will be used for the underlying value estimator. Defaults to ``"reward"``.
+            done (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is done. Will be used for the underlying value estimator.
+                Defaults to ``"done"``.
+            terminated (NestedKey): The key in the input TensorDict that indicates
+                whether a trajectory is terminated. Will be used for the underlying value estimator.
+                Defaults to ``"terminated"``.
+ """ + + advantage: NestedKey = "advantage" + value_target: NestedKey = "value_target" + value: NestedKey = "state_value" + sample_log_prob: NestedKey = "sample_log_prob" + action: NestedKey = "action" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.GAE + + + actor_network: TensorDictModule + critic_network: TensorDictModule + actor_network_params: TensorDictParams + critic_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_critic_network_params: TensorDictParams + + def __init__( + self, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, + trust_region_layer: any | None = None, + entropy_bonus: bool = True, + samples_mc_entropy: int = 1, + entropy_coef: float = 0.01, + critic_coef: float = 1.0, + trust_region_coef: float = 10.0, + loss_critic_type: str = "smooth_l1", + normalize_advantage: bool = False, + gamma: float = None, + separate_losses: bool = False, + reduction: str = None, + clip_value: bool | float | None = None, + **kwargs, + ): + self.trust_region_layer = trust_region_layer + self.trust_region_coef = trust_region_coef + + super(TRPLLoss, self).__init__( + actor_network, + critic_network, + entropy_bonus=entropy_bonus, + samples_mc_entropy=samples_mc_entropy, + entropy_coef=entropy_coef, + critic_coef=critic_coef, + loss_critic_type=loss_critic_type, + normalize_advantage=normalize_advantage, + gamma=gamma, + separate_losses=separate_losses, + reduction=reduction, + clip_value=clip_value, + **kwargs, + ) + + @property + def out_keys(self): + if self._out_keys is None: + keys = ["loss_objective"] + if self.entropy_bonus: + keys.extend(["entropy", "loss_entropy"]) + if self.loss_critic: + keys.append("loss_critic") + keys.append("ESS") + self._out_keys = keys + return self._out_keys + + @out_keys.setter + def out_keys(self, values): + self._out_keys = values + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict = tensordict.clone(False) + advantage = tensordict.get(self.tensor_keys.advantage, None) + if advantage is None: + self.value_estimator( + tensordict, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, + ) + advantage = tensordict.get(self.tensor_keys.advantage) + if self.normalize_advantage and advantage.numel() > 1: + loc = advantage.mean() + scale = advantage.std().clamp_min(1e-6) + advantage = (advantage - loc) / scale + + log_weight, dist, kl_approx = self._log_weight(tensordict) + trust_region_loss_unscaled = self._trust_region_loss(tensordict) + # ESS for logging + with torch.no_grad(): + # In theory, ESS should be computed on particles sampled from the same source. Here we sample according + # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion + # of the weights. 
+            lw = log_weight.squeeze()
+            ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
+            batch = log_weight.shape[0]
+
+        surrogate_gain = log_weight.exp() * advantage
+        trust_region_loss = trust_region_loss_unscaled * self.trust_region_coef
+
+        loss = -surrogate_gain + trust_region_loss
+        td_out = TensorDict({"loss_objective": loss}, batch_size=[])
+        td_out.set("tr_loss", trust_region_loss)
+
+        if self.entropy_bonus:
+            entropy = self.get_entropy_bonus(dist)
+            td_out.set("entropy", entropy.detach().mean())  # for logging
+            td_out.set("kl_approx", kl_approx.detach().mean())  # for logging
+            td_out.set("loss_entropy", -self.entropy_coef * entropy)
+        if self.critic_coef:
+            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+            td_out.set("loss_critic", loss_critic)
+            if value_clip_fraction is not None:
+                td_out.set("value_clip_fraction", value_clip_fraction)
+
+        td_out.set("ESS", _reduce(ess, self.reduction) / batch)
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
+        )
+        return td_out
+
+    def _trust_region_loss(self, tensordict):
+        # Assumption: the current (unprojected) policy is read from the live actor parameters and the
+        # old/behavior policy from the detached target actor parameters; adjust if the old policy is stored elsewhere.
+        with self.actor_network_params.to_module(self.actor_network) if self.functional else contextlib.nullcontext():
+            raw_distribution = self.actor_network.get_dist(tensordict)
+        with self.target_actor_network_params.to_module(self.actor_network) if self.functional else contextlib.nullcontext():
+            old_distribution = self.actor_network.get_dist(tensordict)
+        return self.trust_region_layer.get_trust_region_loss(raw_distribution, old_distribution)
\ No newline at end of file
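For reference, a minimal usage sketch of the new loss module, assuming a simple Gaussian policy and a stand-in DummyProjection object in place of a real fancy_rl projection layer (the actual projection class, import path and constructor arguments are not part of this diff):

    import torch
    from torch import nn
    from tensordict.nn import (
        NormalParamExtractor,
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )
    from torchrl.modules import MLP

    from fancy_rl.objectives import TRPLLoss

    obs_dim, act_dim = 8, 2

    # Policy: observation -> (loc, scale) -> Normal(action), with log-probs recorded
    # under the default "sample_log_prob" key expected by the loss.
    actor = ProbabilisticTensorDictSequential(
        TensorDictModule(
            nn.Sequential(
                MLP(in_features=obs_dim, out_features=2 * act_dim, num_cells=[64, 64]),
                NormalParamExtractor(),
            ),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=torch.distributions.Normal,
            return_log_prob=True,
        ),
    )

    # State-value network writing to the default "state_value" key.
    critic = TensorDictModule(
        MLP(in_features=obs_dim, out_features=1, num_cells=[64, 64]),
        in_keys=["observation"],
        out_keys=["state_value"],
    )

    class DummyProjection:
        """Stand-in for a trust-region/projection layer; only illustrates the
        get_trust_region_loss(raw, old) contract used by TRPLLoss."""

        def get_trust_region_loss(self, raw_distribution, old_distribution):
            # Penalize divergence between the unprojected and the old policy.
            return torch.distributions.kl.kl_divergence(
                old_distribution, raw_distribution
            ).mean()

    loss_module = TRPLLoss(
        actor,
        critic,
        trust_region_layer=DummyProjection(),
        trust_region_coef=10.0,
        entropy_coef=0.01,
    )
    # Default estimator is GAE (see default_value_estimator above).
    loss_module.make_value_estimator(gamma=0.99, lmbda=0.95)

    # The module is then called on collector batches containing "observation", "action",
    # "sample_log_prob" and the usual ("next", ...) entries; the total loss is the sum of
    # "loss_objective", "loss_critic" and "loss_entropy" from the returned tensordict.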