Compare commits
2 Commits
b34224f189 ... c6a12aa27b
Author | SHA1 | Date
---|---|---
 | c6a12aa27b |
 | 3931f5e31b |
fancy_rl/algos/trpl.py (new file, +114 lines)
@@ -0,0 +1,114 @@
import torch
from torchrl.modules import ActorValueOperator, ProbabilisticActor
from torchrl.objectives.value.advantages import GAE

from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic, SharedModule
from fancy_rl.objectives import TRPLLoss


class TRPL(OnPolicy):
    def __init__(
        self,
        env_spec,
        loggers=None,
        actor_hidden_sizes=[64, 64],
        critic_hidden_sizes=[64, 64],
        actor_activation_fn="Tanh",
        critic_activation_fn="Tanh",
        shared_stem_sizes=[64],
        proj_layer_type=None,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        total_timesteps=1e6,
        eval_interval=2048,
        eval_deterministic=True,
        entropy_coef=0.01,
        critic_coef=0.5,
        trust_region_coef=10.0,
        normalize_advantage=True,
        device=None,
        env_spec_eval=None,
        eval_episodes=10,
    ):
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.trust_region_layer = None  # TODO: construct from proj_layer_type

        # Initialize environment to get observation and action space sizes
        self.env_spec = env_spec
        env = self.make_env()
        obs_space = env.observation_space
        act_space = env.action_space

        # Define the shared, actor, and critic modules
        self.shared_module = SharedModule(obs_space, shared_stem_sizes, actor_activation_fn, device)
        self.raw_actor = Actor(self.shared_module, act_space, actor_hidden_sizes, actor_activation_fn, device)
        self.critic = Critic(self.shared_module, critic_hidden_sizes, critic_activation_fn, device)

        # Perform projection
        self.actor = self.raw_actor  # TODO: project the raw actor through the trust region layer

        # Combine into an ActorValueOperator
        self.ac_module = ActorValueOperator(
            self.shared_module,
            self.actor,
            self.critic,
        )

        # Define the policy as a ProbabilisticActor
        policy = ProbabilisticActor(
            module=self.ac_module.get_policy_operator(),
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=torch.distributions.Normal,
            return_log_prob=True,
        )

        optimizers = {
            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
            "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate),
        }

        # Use the constructor arguments directly here; the corresponding attributes
        # are not available until super().__init__() runs below.
        self.adv_module = GAE(
            gamma=gamma,
            lmbda=gae_lambda,
            value_network=self.critic,
            average_gae=False,
        )

        self.loss_module = TRPLLoss(
            actor_network=self.actor,
            critic_network=self.critic,
            trust_region_layer=self.trust_region_layer,
            loss_critic_type="MSELoss",
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            trust_region_coef=trust_region_coef,
            normalize_advantage=normalize_advantage,
        )

        super().__init__(
            policy=policy,
            env_spec=env_spec,
            loggers=loggers,
            optimizers=optimizers,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            total_timesteps=total_timesteps,
            eval_interval=eval_interval,
            eval_deterministic=eval_deterministic,
            entropy_coef=entropy_coef,
            critic_coef=critic_coef,
            normalize_advantage=normalize_advantage,
            device=device,
            env_spec_eval=env_spec_eval,
            eval_episodes=eval_episodes,
        )
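For orientation, here is a minimal usage sketch of the new TRPL class. It is not part of the diff and rests on two assumptions that this compare view does not confirm: that the OnPolicy base class provides the make_env() helper used above and a train() entry point, and that a Gymnasium environment id is a valid env_spec.

from fancy_rl.algos.trpl import TRPL

# Hypothetical example; environment id and training entry point are assumptions,
# since OnPolicy is not included in this diff.
algo = TRPL(
    env_spec="HalfCheetah-v4",   # hypothetical env_spec
    trust_region_coef=10.0,
    total_timesteps=1e6,
)
algo.train()  # assumed to be provided by OnPolicy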
fancy_rl/objectives/__init__.py (new file, +1 line)
@@ -0,0 +1 @@
from fancy_rl.objectives.trpl import TRPLLoss
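This re-export is what allows the algorithm module above to import the loss from the package root rather than from the submodule:

from fancy_rl.objectives import TRPLLoss  # equivalent to fancy_rl.objectives.trpl.TRPLLoss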
fancy_rl/objectives/trpl.py (new file, +201 lines)
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
import math
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tensordict import TensorDict, TensorDictBase, TensorDictParams
|
||||||
|
from tensordict.nn import (
|
||||||
|
dispatch,
|
||||||
|
ProbabilisticTensorDictModule,
|
||||||
|
ProbabilisticTensorDictSequential,
|
||||||
|
TensorDictModule,
|
||||||
|
)
|
||||||
|
from tensordict.utils import NestedKey
|
||||||
|
from torch import distributions as d
|
||||||
|
|
||||||
|
from torchrl.objectives.common import LossModule
|
||||||
|
|
||||||
|
from torchrl.objectives.utils import (
|
||||||
|
_cache_values,
|
||||||
|
_clip_value_loss,
|
||||||
|
_GAMMA_LMBDA_DEPREC_ERROR,
|
||||||
|
_reduce,
|
||||||
|
default_value_kwargs,
|
||||||
|
distance_loss,
|
||||||
|
ValueEstimators,
|
||||||
|
)
|
||||||
|
from torchrl.objectives.value import (
|
||||||
|
GAE,
|
||||||
|
TD0Estimator,
|
||||||
|
TD1Estimator,
|
||||||
|
TDLambdaEstimator,
|
||||||
|
VTrace,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torchrl.objectives.ppo import PPOLoss
|
||||||
|
|
||||||
|
class TRPLLoss(PPOLoss):
|
||||||
|
@dataclass
|
||||||
|
class _AcceptedKeys:
|
||||||
|
"""Maintains default values for all configurable tensordict keys.
|
||||||
|
|
||||||
|
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
|
||||||
|
default values
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
advantage (NestedKey): The input tensordict key where the advantage is expected.
|
||||||
|
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
|
||||||
|
value_target (NestedKey): The input tensordict key where the target state value is expected.
|
||||||
|
Will be used for the underlying value estimator Defaults to ``"value_target"``.
|
||||||
|
value (NestedKey): The input tensordict key where the state value is expected.
|
||||||
|
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
|
||||||
|
sample_log_prob (NestedKey): The input tensordict key where the
|
||||||
|
sample log probability is expected. Defaults to ``"sample_log_prob"``.
|
||||||
|
action (NestedKey): The input tensordict key where the action is expected.
|
||||||
|
Defaults to ``"action"``.
|
||||||
|
reward (NestedKey): The input tensordict key where the reward is expected.
|
||||||
|
Will be used for the underlying value estimator. Defaults to ``"reward"``.
|
||||||
|
done (NestedKey): The key in the input TensorDict that indicates
|
||||||
|
whether a trajectory is done. Will be used for the underlying value estimator.
|
||||||
|
Defaults to ``"done"``.
|
||||||
|
terminated (NestedKey): The key in the input TensorDict that indicates
|
||||||
|
whether a trajectory is terminated. Will be used for the underlying value estimator.
|
||||||
|
Defaults to ``"terminated"``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
advantage: NestedKey = "advantage"
|
||||||
|
value_target: NestedKey = "value_target"
|
||||||
|
value: NestedKey = "state_value"
|
||||||
|
sample_log_prob: NestedKey = "sample_log_prob"
|
||||||
|
action: NestedKey = "action"
|
||||||
|
reward: NestedKey = "reward"
|
||||||
|
done: NestedKey = "done"
|
||||||
|
terminated: NestedKey = "terminated"
|
||||||
|
|
||||||
|
default_keys = _AcceptedKeys()
|
||||||
|
default_value_estimator = ValueEstimators.GAE
|
||||||
|
|
||||||
|
|
||||||
|
actor_network: TensorDictModule
|
||||||
|
critic_network: TensorDictModule
|
||||||
|
actor_network_params: TensorDictParams
|
||||||
|
critic_network_params: TensorDictParams
|
||||||
|
target_actor_network_params: TensorDictParams
|
||||||
|
target_critic_network_params: TensorDictParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
actor_network: ProbabilisticTensorDictSequential | None = None,
|
||||||
|
critic_network: TensorDictModule | None = None,
|
||||||
|
trust_region_layer: any | None = None,
|
||||||
|
entropy_bonus: bool = True,
|
||||||
|
samples_mc_entropy: int = 1,
|
||||||
|
entropy_coef: float = 0.01,
|
||||||
|
critic_coef: float = 1.0,
|
||||||
|
trust_region_coef: float = 10.0,
|
||||||
|
loss_critic_type: str = "smooth_l1",
|
||||||
|
normalize_advantage: bool = False,
|
||||||
|
gamma: float = None,
|
||||||
|
separate_losses: bool = False,
|
||||||
|
reduction: str = None,
|
||||||
|
clip_value: bool | float | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.trust_region_layer = trust_region_layer
|
||||||
|
self.trust_region_coef = trust_region_coef
|
||||||
|
|
||||||
|
super(TRPLLoss, self).__init__(
|
||||||
|
actor_network,
|
||||||
|
critic_network,
|
||||||
|
entropy_bonus=entropy_bonus,
|
||||||
|
samples_mc_entropy=samples_mc_entropy,
|
||||||
|
entropy_coef=entropy_coef,
|
||||||
|
critic_coef=critic_coef,
|
||||||
|
loss_critic_type=loss_critic_type,
|
||||||
|
normalize_advantage=normalize_advantage,
|
||||||
|
gamma=gamma,
|
||||||
|
separate_losses=separate_losses,
|
||||||
|
reduction=reduction,
|
||||||
|
clip_value=clip_value,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def out_keys(self):
|
||||||
|
if self._out_keys is None:
|
||||||
|
keys = ["loss_objective"]
|
||||||
|
if self.entropy_bonus:
|
||||||
|
keys.extend(["entropy", "loss_entropy"])
|
||||||
|
if self.loss_critic:
|
||||||
|
keys.append("loss_critic")
|
||||||
|
keys.append("ESS")
|
||||||
|
self._out_keys = keys
|
||||||
|
return self._out_keys
|
||||||
|
|
||||||
|
@out_keys.setter
|
||||||
|
def out_keys(self, values):
|
||||||
|
self._out_keys = values
|
||||||
|
|
||||||
|
@dispatch
|
||||||
|
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||||
|
tensordict = tensordict.clone(False)
|
||||||
|
advantage = tensordict.get(self.tensor_keys.advantage, None)
|
||||||
|
if advantage is None:
|
||||||
|
self.value_estimator(
|
||||||
|
tensordict,
|
||||||
|
params=self._cached_critic_network_params_detached,
|
||||||
|
target_params=self.target_critic_network_params,
|
||||||
|
)
|
||||||
|
advantage = tensordict.get(self.tensor_keys.advantage)
|
||||||
|
if self.normalize_advantage and advantage.numel() > 1:
|
||||||
|
loc = advantage.mean()
|
||||||
|
scale = advantage.std().clamp_min(1e-6)
|
||||||
|
advantage = (advantage - loc) / scale
|
||||||
|
|
||||||
|
log_weight, dist, kl_approx = self._log_weight(tensordict)
|
||||||
|
trust_region_loss_unscaled = self._trust_region_loss(tensordict)
|
||||||
|
# ESS for logging
|
||||||
|
with torch.no_grad():
|
||||||
|
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
|
||||||
|
# to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
|
||||||
|
# of the weights.
|
||||||
|
lw = log_weight.squeeze()
|
||||||
|
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
|
||||||
|
batch = log_weight.shape[0]
|
||||||
|
|
||||||
|
surrogate_gain = log_weight.exp() * advantage
|
||||||
|
trust_region_loss = trust_region_loss_unscaled * self.trust_region_coef
|
||||||
|
|
||||||
|
loss = -surrogate_gain + trust_region_loss
|
||||||
|
td_out = TensorDict({"loss_objective": loss}, batch_size=[])
|
||||||
|
td_out.set("tr_loss", trust_region_loss)
|
||||||
|
|
||||||
|
if self.entropy_bonus:
|
||||||
|
entropy = self.get_entropy_bonus(dist)
|
||||||
|
td_out.set("entropy", entropy.detach().mean()) # for logging
|
||||||
|
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
|
||||||
|
td_out.set("loss_entropy", -self.entropy_coef * entropy)
|
||||||
|
if self.critic_coef:
|
||||||
|
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
|
||||||
|
td_out.set("loss_critic", loss_critic)
|
||||||
|
if value_clip_fraction is not None:
|
||||||
|
td_out.set("value_clip_fraction", value_clip_fraction)
|
||||||
|
|
||||||
|
td_out.set("ESS", _reduce(ess, self.reduction) / batch)
|
||||||
|
td_out = td_out.named_apply(
|
||||||
|
lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
|
||||||
|
if name.startswith("loss_")
|
||||||
|
else value,
|
||||||
|
batch_size=[],
|
||||||
|
)
|
||||||
|
return td_out
|
||||||
|
|
||||||
|
def _trust_region_loss(self, tensordict):
|
||||||
|
old_distribution =
|
||||||
|
raw_distribution =
|
||||||
|
return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
|
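To show how the pieces fit together, here is a rough construction sketch, not part of the diff, using toy tensordict modules. The projection layer is passed as None because this commit does not ship one yet (TRPL.__init__ also sets it to None), and the key remapping uses the set_keys mechanism referenced in the _AcceptedKeys docstring.

import torch
from tensordict.nn import TensorDictModule, NormalParamExtractor
from torchrl.modules import ProbabilisticActor

from fancy_rl.objectives import TRPLLoss

# Toy actor: maps a 4-dim observation to loc/scale of a 2-dim Normal policy.
actor_net = torch.nn.Sequential(torch.nn.Linear(4, 4), NormalParamExtractor())
actor = ProbabilisticActor(
    module=TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=torch.distributions.Normal,
    return_log_prob=True,
)
# Toy critic producing "state_value", the default value key.
critic = TensorDictModule(torch.nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"])

loss_module = TRPLLoss(
    actor_network=actor,
    critic_network=critic,
    trust_region_layer=None,   # no projection layer is shipped in this commit
    trust_region_coef=10.0,
)
# Tensordict keys can be remapped as documented in _AcceptedKeys:
loss_module.set_keys(advantage="advantage")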