Refactor objective losses

Dominik Moritz Roth 2024-08-28 11:32:51 +02:00
parent 25988bab54
commit dd98af9f77


@@ -38,100 +38,40 @@ from torchrl.objectives.value import (
 )
 from torchrl.objectives.ppo import PPOLoss
+from fancy_rl.projections import get_projection


 class TRPLLoss(PPOLoss):
-    @dataclass
-    class _AcceptedKeys:
-        """Maintains default values for all configurable tensordict keys.
-
-        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
-        default values
-
-        Attributes:
-            advantage (NestedKey): The input tensordict key where the advantage is expected.
-                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
-            value_target (NestedKey): The input tensordict key where the target state value is expected.
-                Will be used for the underlying value estimator Defaults to ``"value_target"``.
-            value (NestedKey): The input tensordict key where the state value is expected.
-                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
-            sample_log_prob (NestedKey): The input tensordict key where the
-                sample log probability is expected. Defaults to ``"sample_log_prob"``.
-            action (NestedKey): The input tensordict key where the action is expected.
-                Defaults to ``"action"``.
-            reward (NestedKey): The input tensordict key where the reward is expected.
-                Will be used for the underlying value estimator. Defaults to ``"reward"``.
-            done (NestedKey): The key in the input TensorDict that indicates
-                whether a trajectory is done. Will be used for the underlying value estimator.
-                Defaults to ``"done"``.
-            terminated (NestedKey): The key in the input TensorDict that indicates
-                whether a trajectory is terminated. Will be used for the underlying value estimator.
-                Defaults to ``"terminated"``.
-        """
-
-        advantage: NestedKey = "advantage"
-        value_target: NestedKey = "value_target"
-        value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
-        action: NestedKey = "action"
-        reward: NestedKey = "reward"
-        done: NestedKey = "done"
-        terminated: NestedKey = "terminated"
-
-    default_keys = _AcceptedKeys()
-    default_value_estimator = ValueEstimators.GAE
-
-    actor_network: TensorDictModule
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
-
     def __init__(
         self,
-        actor_network: ProbabilisticTensorDictSequential | None = None,
-        critic_network: TensorDictModule | None = None,
-        trust_region_layer: any | None = None,
-        entropy_bonus: bool = True,
-        samples_mc_entropy: int = 1,
+        actor_network: ProbabilisticTensorDictSequential,
+        old_actor_network: ProbabilisticTensorDictSequential,
+        critic_network: TensorDictModule,
+        projection: any,
         entropy_coef: float = 0.01,
         critic_coef: float = 1.0,
         trust_region_coef: float = 10.0,
-        loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        gamma: float = None,
-        separate_losses: bool = False,
-        reduction: str = None,
-        clip_value: bool | float | None = None,
         **kwargs,
     ):
-        self.trust_region_layer = trust_region_layer
-        self.trust_region_coef = trust_region_coef
-
-        super(TRPLLoss, self).__init__(
-            actor_network,
-            critic_network,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
+        super().__init__(
+            actor_network=actor_network,
+            critic_network=critic_network,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
-            loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
-            gamma=gamma,
-            separate_losses=separate_losses,
-            reduction=reduction,
-            clip_value=clip_value,
             **kwargs,
         )
+        self.old_actor_network = old_actor_network
+        self.projection = projection
+        self.trust_region_coef = trust_region_coef

     @property
     def out_keys(self):
         if self._out_keys is None:
-            keys = ["loss_objective"]
+            keys = ["loss_objective", "tr_loss"]
             if self.entropy_bonus:
                 keys.extend(["entropy", "loss_entropy"])
-            if self.loss_critic:
+            if self.critic_coef:
                 keys.append("loss_critic")
             keys.append("ESS")
             self._out_keys = keys
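After this change, the old (behavior) policy and the projection are passed to the constructor explicitly instead of a trust_region_layer. A minimal construction sketch under stated assumptions: the fancy_rl.objectives import path, the toy dimensions, and the no-op projection placeholder (any callable with the signature used by _trust_region_loss in the next hunk) are not part of this diff.

import copy

import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator

from fancy_rl.objectives import TRPLLoss  # assumed import path, not shown in this diff

obs_dim, act_dim = 4, 2  # toy sizes

# Current policy: emits loc/scale for a TanhNormal action distribution.
policy_net = nn.Sequential(
    nn.Linear(obs_dim, 64), nn.Tanh(),
    nn.Linear(64, 2 * act_dim), NormalParamExtractor(),
)
actor = ProbabilisticActor(
    TensorDictModule(policy_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Frozen snapshot of the policy that anchors the trust region.
old_actor = copy.deepcopy(actor)

# State-value critic.
critic = ValueOperator(
    nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 1)),
    in_keys=["observation"],
)

loss_module = TRPLLoss(
    actor_network=actor,
    old_actor_network=old_actor,
    critic_network=critic,
    # Placeholder only: a real projection would come from
    # fancy_rl.projections.get_projection (its signature is not shown here).
    projection=lambda actor_network, raw, old: torch.zeros(()),
    entropy_coef=0.01,
    critic_coef=1.0,
    trust_region_coef=10.0,
)

The lambda is there only to keep the sketch self-contained; a distribution-level stand-in is sketched after the next hunk.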
@@ -141,8 +81,12 @@ class TRPLLoss(PPOLoss):
     def out_keys(self, values):
         self._out_keys = values

-    @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def _trust_region_loss(self, tensordict):
+        old_distribution = self.old_actor_network(tensordict)
+        raw_distribution = self.actor_network(tensordict)
+        return self.projection(self.actor_network, raw_distribution, old_distribution)
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
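The new _trust_region_loss delegates entirely to the injected projection callable. As an illustration of that call signature only (not fancy_rl's actual projections, which live behind get_projection and project the raw distribution back into the trust region rather than merely penalizing it), here is a toy callable that returns a mean KL term between two Gaussian policies:

import torch
from torch.distributions import Normal, kl_divergence

def toy_kl_projection(actor_network, raw_distribution, old_distribution):
    # Same call signature as self.projection(...) in _trust_region_loss above.
    # Toy behavior: return the mean KL(old || raw) as a trust-region penalty,
    # assuming both arguments are torch.distributions objects.
    return kl_divergence(old_distribution, raw_distribution).mean()

# Standalone check with two batches of diagonal Gaussians.
old_dist = Normal(loc=torch.zeros(8, 2), scale=torch.ones(8, 2))
new_dist = Normal(loc=0.1 * torch.ones(8, 2), scale=1.1 * torch.ones(8, 2))
print(toy_kl_projection(None, new_dist, old_dist))  # small positive scalar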
@@ -159,11 +103,8 @@ class TRPLLoss(PPOLoss):
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         trust_region_loss_unscaled = self._trust_region_loss(tensordict)

-        # ESS for logging
         with torch.no_grad():
-            # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
-            # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
-            # of the weights.
             lw = log_weight.squeeze()
             ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
             batch = log_weight.shape[0]
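The retained ess line computes the effective sample size of the importance weights w = exp(lw) in log space, i.e. ESS = (sum_i w_i)^2 / sum_i w_i^2. A small self-contained check of that identity on toy log-weights (not data from the loss):

import torch

torch.manual_seed(0)
lw = torch.randn(256)  # toy log importance weights

# Log-space form used above: exp(2 * logsumexp(lw) - logsumexp(2 * lw))
ess_logspace = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()

# Direct form: (sum w)^2 / sum w^2
w = lw.exp()
ess_direct = w.sum() ** 2 / (w ** 2).sum()

print(ess_logspace, ess_direct)  # equal up to float error, between 1 and len(lw)
assert torch.allclose(ess_logspace, ess_direct, rtol=1e-4)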
@@ -194,8 +135,3 @@ class TRPLLoss(PPOLoss):
             batch_size=[],
         )
         return td_out
-
-    def _trust_region_loss(self, tensordict):
-        old_distribution =
-        raw_distribution =
-        return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)