Refactor objective losses
parent 25988bab54
commit dd98af9f77
@@ -38,100 +38,40 @@ from torchrl.objectives.value import (
 )
 
 from torchrl.objectives.ppo import PPOLoss
+from fancy_rl.projections import get_projection
 
 class TRPLLoss(PPOLoss):
-    @dataclass
-    class _AcceptedKeys:
-        """Maintains default values for all configurable tensordict keys.
-
-        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
-        default values
-
-        Attributes:
-            advantage (NestedKey): The input tensordict key where the advantage is expected.
-                Will be used for the underlying value estimator. Defaults to ``"advantage"``.
-            value_target (NestedKey): The input tensordict key where the target state value is expected.
-                Will be used for the underlying value estimator Defaults to ``"value_target"``.
-            value (NestedKey): The input tensordict key where the state value is expected.
-                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
-            sample_log_prob (NestedKey): The input tensordict key where the
-                sample log probability is expected. Defaults to ``"sample_log_prob"``.
-            action (NestedKey): The input tensordict key where the action is expected.
-                Defaults to ``"action"``.
-            reward (NestedKey): The input tensordict key where the reward is expected.
-                Will be used for the underlying value estimator. Defaults to ``"reward"``.
-            done (NestedKey): The key in the input TensorDict that indicates
-                whether a trajectory is done. Will be used for the underlying value estimator.
-                Defaults to ``"done"``.
-            terminated (NestedKey): The key in the input TensorDict that indicates
-                whether a trajectory is terminated. Will be used for the underlying value estimator.
-                Defaults to ``"terminated"``.
-        """
-
-        advantage: NestedKey = "advantage"
-        value_target: NestedKey = "value_target"
-        value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
-        action: NestedKey = "action"
-        reward: NestedKey = "reward"
-        done: NestedKey = "done"
-        terminated: NestedKey = "terminated"
-
-    default_keys = _AcceptedKeys()
-    default_value_estimator = ValueEstimators.GAE
-
-
-    actor_network: TensorDictModule
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
-
     def __init__(
         self,
-        actor_network: ProbabilisticTensorDictSequential | None = None,
-        critic_network: TensorDictModule | None = None,
-        trust_region_layer: any | None = None,
-        entropy_bonus: bool = True,
-        samples_mc_entropy: int = 1,
+        actor_network: ProbabilisticTensorDictSequential,
+        old_actor_network: ProbabilisticTensorDictSequential,
+        critic_network: TensorDictModule,
+        projection: any,
         entropy_coef: float = 0.01,
         critic_coef: float = 1.0,
         trust_region_coef: float = 10.0,
-        loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
-        gamma: float = None,
-        separate_losses: bool = False,
-        reduction: str = None,
-        clip_value: bool | float | None = None,
         **kwargs,
     ):
-        self.trust_region_layer = trust_region_layer
-        self.trust_region_coef = trust_region_coef
-
-        super(TRPLLoss, self).__init__(
-            actor_network,
-            critic_network,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
+        super().__init__(
+            actor_network=actor_network,
+            critic_network=critic_network,
             entropy_coef=entropy_coef,
             critic_coef=critic_coef,
-            loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
-            gamma=gamma,
-            separate_losses=separate_losses,
-            reduction=reduction,
-            clip_value=clip_value,
             **kwargs,
         )
+        self.old_actor_network = old_actor_network
+        self.projection = projection
+        self.trust_region_coef = trust_region_coef
 
     @property
     def out_keys(self):
         if self._out_keys is None:
-            keys = ["loss_objective"]
+            keys = ["loss_objective", "tr_loss"]
            if self.entropy_bonus:
                 keys.extend(["entropy", "loss_entropy"])
-            if self.loss_critic:
+            if self.critic_coef:
                 keys.append("loss_critic")
             keys.append("ESS")
             self._out_keys = keys
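After this hunk the constructor no longer mirrors every PPOLoss pass-through argument; it requires an explicit reference ("old") actor and a projection object next to the usual actor and critic. A minimal construction sketch follows. It assumes an import path for TRPLLoss and a string-keyed get_projection() call, neither of which is shown in this diff; only the TRPLLoss argument names, their defaults, and the fancy_rl.projections import are taken from the change above.

# Construction sketch only. The fancy_rl import path for TRPLLoss and the
# get_projection("kl") call are assumptions; the TRPLLoss argument names and
# defaults come from the diff above.
import copy

import torch
from torch import nn
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)

from fancy_rl.projections import get_projection  # import added in this commit
from fancy_rl import TRPLLoss                    # assumed import path

obs_dim, act_dim = 4, 2

# Probabilistic actor: observation -> Normal(loc, scale) -> action
actor = ProbabilisticTensorDictSequential(
    TensorDictModule(
        nn.Sequential(nn.Linear(obs_dim, 2 * act_dim), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        return_log_prob=True,
    ),
)

# Critic writing the "state_value" key the loss expects
critic = TensorDictModule(
    nn.Linear(obs_dim, 1), in_keys=["observation"], out_keys=["state_value"]
)

# Frozen snapshot of the policy, used as the trust-region reference
old_actor = copy.deepcopy(actor)

loss_module = TRPLLoss(
    actor_network=actor,
    old_actor_network=old_actor,
    critic_network=critic,
    projection=get_projection("kl"),  # hypothetical signature
    trust_region_coef=10.0,
)

With the defaults above (entropy bonus on, non-zero critic_coef), the refactored out_keys property reports ["loss_objective", "tr_loss", "entropy", "loss_entropy", "loss_critic", "ESS"].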
@@ -141,8 +81,12 @@ class TRPLLoss(PPOLoss):
     def out_keys(self, values):
         self._out_keys = values
 
-    @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def _trust_region_loss(self, tensordict):
+        old_distribution = self.old_actor_network(tensordict)
+        raw_distribution = self.actor_network(tensordict)
+        return self.projection(self.actor_network, raw_distribution, old_distribution)
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
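The incomplete _trust_region_loss stub from the old revision (removed in the last hunk below) is replaced here by a working method that passes the current and reference policy outputs to self.projection. The only contract the diff establishes is the call convention: projection(actor_network, raw_distribution, old_distribution) returning a scalar penalty. The real classes in fancy_rl.projections are not part of this diff; the stand-in below only illustrates that interface and assumes the two distribution arguments are torch.distributions objects, as the variable names suggest.

# Interface sketch, not fancy_rl's actual projection: any object passed as
# `projection` just needs to be callable as
#   projection(actor_network, raw_distribution, old_distribution) -> scalar tensor
from torch.distributions import kl_divergence


class MeanKLPenalty:
    """Stand-in projection: mean KL(old || new) as the unscaled trust-region loss."""

    def __call__(self, actor_network, raw_distribution, old_distribution):
        # actor_network is accepted for signature compatibility but unused here.
        return kl_divergence(old_distribution, raw_distribution).mean()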
@@ -159,11 +103,8 @@ class TRPLLoss(PPOLoss):
 
         log_weight, dist, kl_approx = self._log_weight(tensordict)
         trust_region_loss_unscaled = self._trust_region_loss(tensordict)
-        # ESS for logging
+
         with torch.no_grad():
-            # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
-            # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
-            # of the weights.
             lw = log_weight.squeeze()
             ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
             batch = log_weight.shape[0]
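The ESS line kept in this hunk is unchanged by the commit; for reference, it evaluates the effective sample size of the importance weights in log space, which equals (sum_i w_i)^2 / sum_i w_i^2. A quick numerical check of that identity, with stand-in weights:

# Check that the log-space expression above matches the direct ESS formula.
import torch

lw = torch.randn(64)  # stand-in log importance weights
w = lw.exp()
ess_log_space = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
ess_direct = w.sum() ** 2 / (w * w).sum()
assert torch.allclose(ess_log_space, ess_direct)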
@@ -194,8 +135,3 @@ class TRPLLoss(PPOLoss):
             batch_size=[],
         )
         return td_out
-
-    def _trust_region_loss(self, tensordict):
-        old_distribution =
-        raw_distribution =
-        return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
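The lines that assemble td_out fall outside the displayed hunks, so how the projection penalty enters the output is not visible in this commit. Given the new "tr_loss" out_key and the trust_region_coef argument, the scaled term is presumably produced roughly as sketched below; the key names come from the out_keys property, while the scaling itself is an inference, not shown in the diff.

# Inference only: stand-in values, illustrating the presumed scaling of the
# projection penalty before it is written into the output TensorDict.
import torch
from tensordict import TensorDict

trust_region_coef = 10.0                         # constructor default in the diff
trust_region_loss_unscaled = torch.tensor(0.02)  # stand-in projection output

td_out = TensorDict(
    {
        "loss_objective": torch.tensor(0.5),     # stand-in surrogate loss
        "tr_loss": trust_region_coef * trust_region_loss_unscaled,
    },
    batch_size=[],
)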