Compare commits
No commits in common. "9c55b6a110a179caa594129e9f963f0b4e622be7" and "4240f611ac459670f9a89fb9a750d7cb04fe9aa2" have entirely different histories.
9c55b6a110...4240f611ac
@@ -4,7 +4,6 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from fancy_rl.algos import PPO, TRPL, VLEARN
|
||||
from fancy_rl.projections import get_projection
|
||||
from fancy_rl.algos import PPO
|
||||
|
||||
__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]
|
||||
__all__ = ["PPO"]
|
@@ -1,3 +1 @@
|
||||
from fancy_rl.algos.ppo import PPO
|
||||
from fancy_rl.algos.trpl import TRPL
|
||||
from fancy_rl.algos.vlearn import VLEARN
|
||||
from fancy_rl.algos.ppo import PPO
|
@@ -15,17 +15,23 @@ class OnPolicy(Algo):
|
||||
env_spec,
|
||||
optimizers,
|
||||
loggers=None,
|
||||
actor_hidden_sizes=[64, 64],
|
||||
critic_hidden_sizes=[64, 64],
|
||||
actor_activation_fn="Tanh",
|
||||
critic_activation_fn="Tanh",
|
||||
learning_rate=3e-4,
|
||||
n_steps=2048,
|
||||
batch_size=64,
|
||||
n_epochs=10,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
total_timesteps=1e6,
|
||||
eval_interval=2048,
|
||||
eval_deterministic=True,
|
||||
entropy_coef=0.01,
|
||||
critic_coef=0.5,
|
||||
normalize_advantage=True,
|
||||
clip_range=0.2,
|
||||
env_spec_eval=None,
|
||||
eval_episodes=10,
|
||||
device=None,
|
||||
@@ -71,25 +77,15 @@ class OnPolicy(Algo):
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
def pre_process_batch(self, batch):
|
||||
return batch
|
||||
|
||||
def post_process_batch(self, batch):
|
||||
pass
|
||||
|
||||
def train_step(self, batch):
|
||||
batch = self.pre_process_batch(batch)
|
||||
|
||||
for optimizer in self.optimizers.values():
|
||||
optimizer.zero_grad()
|
||||
losses = self.loss_module(batch)
|
||||
loss = sum(losses.values()) # Sum all losses
|
||||
loss = losses['loss_objective'] + losses["loss_entropy"] + losses["loss_critic"]
|
||||
loss.backward()
|
||||
for optimizer in self.optimizers.values():
|
||||
optimizer.step()
|
||||
|
||||
self.post_process_batch(batch)
|
||||
|
||||
return loss
|
||||
|
||||
def train(self):
|
||||
|
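The train_step hunk above replaces the generic sum over losses.values() with an explicit sum of the named loss terms. A minimal, hedged sketch of that update pattern in isolation (loss_module, optimizers, and batch are placeholders, not the actual fancy_rl objects):

def train_step(loss_module, optimizers, batch):
    # Sketch of the update step shown in the on_policy.py hunk above;
    # all arguments are stand-ins for illustration only.
    for optimizer in optimizers.values():
        optimizer.zero_grad()
    losses = loss_module(batch)  # mapping of named loss terms (TensorDict-like)
    # Summing only the named objectives avoids back-propagating through
    # logging-only outputs such as "entropy" or "ESS".
    loss = losses["loss_objective"] + losses["loss_entropy"] + losses["loss_critic"]
    loss.backward()
    for optimizer in optimizers.values():
        optimizer.step()
    return loss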
@@ -4,7 +4,6 @@ from torchrl.objectives import ClipPPOLoss
|
||||
from torchrl.objectives.value.advantages import GAE
|
||||
from fancy_rl.algos.on_policy import OnPolicy
|
||||
from fancy_rl.policy import Actor, Critic
|
||||
from fancy_rl.projections import get_projection # Updated import
|
||||
|
||||
class PPO(OnPolicy):
|
||||
def __init__(
|
||||
|
@@ -1,16 +1,9 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Dict, Any, Optional
|
||||
from torchrl.modules import ProbabilisticActor, ValueOperator
|
||||
from torchrl.objectives import ClipPPOLoss
|
||||
from torchrl.collectors import SyncDataCollector
|
||||
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
|
||||
from torchrl.objectives.value import GAE
|
||||
from torchrl.modules import ProbabilisticActor
|
||||
from torchrl.objectives.value.advantages import GAE
|
||||
from fancy_rl.algos.on_policy import OnPolicy
|
||||
from fancy_rl.policy import Actor, Critic
|
||||
from fancy_rl.projections import get_projection, BaseProjection
|
||||
from fancy_rl.objectives import TRPLLoss
|
||||
from copy import deepcopy
|
||||
|
||||
class TRPL(OnPolicy):
|
||||
def __init__(
|
||||
@@ -21,21 +14,19 @@ class TRPL(OnPolicy):
|
||||
critic_hidden_sizes=[64, 64],
|
||||
actor_activation_fn="Tanh",
|
||||
critic_activation_fn="Tanh",
|
||||
proj_layer_type=None,
|
||||
learning_rate=3e-4,
|
||||
n_steps=2048,
|
||||
batch_size=64,
|
||||
n_epochs=10,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
projection_class="identity_projection",
|
||||
trust_region_coef=10.0,
|
||||
trust_region_bound_mean=0.1,
|
||||
trust_region_bound_cov=0.001,
|
||||
total_timesteps=1e6,
|
||||
eval_interval=2048,
|
||||
eval_deterministic=True,
|
||||
entropy_coef=0.01,
|
||||
critic_coef=0.5,
|
||||
trust_region_coef=10.0,
|
||||
normalize_advantage=False,
|
||||
device=None,
|
||||
env_spec_eval=None,
|
||||
@@ -44,6 +35,9 @@ class TRPL(OnPolicy):
|
||||
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.device = device
|
||||
|
||||
self.trust_region_layer = None # TODO: from proj_layer_type
|
||||
self.trust_region_coef = trust_region_coef
|
||||
|
||||
# Initialize environment to get observation and action space sizes
|
||||
self.env_spec = env_spec
|
||||
env = self.make_env()
|
||||
@@ -52,40 +46,14 @@ class TRPL(OnPolicy):
|
||||
|
||||
self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
|
||||
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
|
||||
|
||||
# Handle projection_class
|
||||
if isinstance(projection_class, str):
|
||||
projection_class = get_projection(projection_class)
|
||||
elif not issubclass(projection_class, BaseProjection):
|
||||
raise ValueError("projection_class must be a string or a subclass of BaseProjection")
|
||||
|
||||
self.projection = projection_class(
|
||||
in_keys=["loc", "scale"],
|
||||
out_keys=["loc", "scale"],
|
||||
trust_region_bound_mean=trust_region_bound_mean,
|
||||
trust_region_bound_cov=trust_region_bound_cov
|
||||
)
|
||||
|
||||
self.actor = ProbabilisticActor(
|
||||
raw_actor = ProbabilisticActor(
|
||||
module=actor_net,
|
||||
in_keys=["observation"],
|
||||
out_keys=["loc", "scale"],
|
||||
in_keys=["loc", "scale"],
|
||||
out_keys=["action"],
|
||||
distribution_class=torch.distributions.Normal,
|
||||
return_log_prob=True
|
||||
)
|
||||
self.old_actor = deepcopy(self.actor)
|
||||
|
||||
self.trust_region_coef = trust_region_coef
|
||||
self.loss_module = TRPLLoss(
|
||||
actor_network=self.actor,
|
||||
old_actor_network=self.old_actor,
|
||||
critic_network=self.critic,
|
||||
projection=self.projection,
|
||||
entropy_coef=entropy_coef,
|
||||
critic_coef=critic_coef,
|
||||
trust_region_coef=trust_region_coef,
|
||||
normalize_advantage=normalize_advantage,
|
||||
)
|
||||
self.actor = raw_actor # TODO: Proj here
|
||||
|
||||
optimizers = {
|
||||
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
|
||||
@@ -111,6 +79,7 @@ class TRPL(OnPolicy):
|
||||
env_spec_eval=env_spec_eval,
|
||||
eval_episodes=eval_episodes,
|
||||
)
|
||||
|
||||
self.adv_module = GAE(
|
||||
gamma=self.gamma,
|
||||
lmbda=gae_lambda,
|
||||
@@ -118,24 +87,13 @@ class TRPL(OnPolicy):
|
||||
average_gae=False,
|
||||
)
|
||||
|
||||
def update_old_policy(self):
|
||||
self.old_actor.load_state_dict(self.actor.state_dict())
|
||||
|
||||
def project_policy(self, obs):
|
||||
with torch.no_grad():
|
||||
old_dist = self.old_actor(obs)
|
||||
new_dist = self.actor(obs)
|
||||
projected_params = self.projection.project(new_dist, old_dist)
|
||||
return projected_params
|
||||
|
||||
def pre_update(self, tensordict):
|
||||
obs = tensordict["observation"]
|
||||
projected_dist = self.project_policy(obs)
|
||||
|
||||
# Update tensordict with projected distribution parameters
|
||||
tensordict["projected_loc"] = projected_dist[0]
|
||||
tensordict["projected_scale"] = projected_dist[1]
|
||||
return tensordict
|
||||
|
||||
def post_update(self):
|
||||
self.update_old_policy()
|
||||
self.loss_module = TRPLLoss(
|
||||
actor_network=self.actor,
|
||||
critic_network=self.critic,
|
||||
trust_region_layer=self.trust_region_layer,
|
||||
loss_critic_type='l2',
|
||||
entropy_coef=self.entropy_coef,
|
||||
critic_coef=self.critic_coef,
|
||||
trust_region_coef=self.trust_region_coef,
|
||||
normalize_advantage=self.normalize_advantage,
|
||||
)
|
||||
|
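The TRPL constructor above accepts projection_class either as a registry name or as a BaseProjection subclass. A hedged sketch of that normalization step on its own (get_projection and BaseProjection are the names used in this diff; note the right-hand commit no longer exports get_projection):

def resolve_projection(projection_class):
    # Sketch of the projection_class handling in the TRPL constructor above;
    # assumes the left-hand commit, where fancy_rl.projections still exports
    # get_projection and BaseProjection.
    from fancy_rl.projections import get_projection, BaseProjection
    if isinstance(projection_class, str):
        return get_projection(projection_class)
    if isinstance(projection_class, type) and issubclass(projection_class, BaseProjection):
        return projection_class
    raise ValueError("projection_class must be a string or a subclass of BaseProjection")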
@@ -1,114 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Dict, Any, Optional
|
||||
from torchrl.modules import ProbabilisticActor, ValueOperator
|
||||
from torchrl.collectors import SyncDataCollector
|
||||
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
|
||||
|
||||
from fancy_rl.utils import get_env, get_actor, get_critic
|
||||
from fancy_rl.modules.vlearn_loss import VLEARNLoss
|
||||
from fancy_rl.modules.projection import get_vlearn_projection
|
||||
from fancy_rl.modules.squashed_normal import get_squashed_normal
|
||||
|
||||
class VLEARN:
|
||||
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):
|
||||
self.device = torch.device(device)
|
||||
self.env = get_env(env_id)
|
||||
|
||||
self.projection = get_vlearn_projection(**kwargs.get("projection", {}))
|
||||
|
||||
actor = get_actor(self.env, **kwargs.get("actor", {}))
|
||||
self.actor = ProbabilisticActor(
|
||||
actor,
|
||||
in_keys=["observation"],
|
||||
out_keys=["loc", "scale"],
|
||||
distribution_class=get_squashed_normal(),
|
||||
return_log_prob=True
|
||||
).to(self.device)
|
||||
self.old_actor = self.actor.clone()
|
||||
|
||||
self.critic = ValueOperator(
|
||||
module=get_critic(self.env, **kwargs.get("critic", {})),
|
||||
in_keys=["observation"]
|
||||
).to(self.device)
|
||||
|
||||
self.collector = SyncDataCollector(
|
||||
self.env,
|
||||
self.actor,
|
||||
frames_per_batch=kwargs.get("frames_per_batch", 1000),
|
||||
total_frames=kwargs.get("total_frames", -1),
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.replay_buffer = TensorDictReplayBuffer(
|
||||
storage=LazyMemmapStorage(kwargs.get("buffer_size", 100000)),
|
||||
batch_size=kwargs.get("batch_size", 256),
|
||||
)
|
||||
|
||||
self.loss_module = VLEARNLoss(
|
||||
actor_network=self.actor,
|
||||
critic_network=self.critic,
|
||||
old_actor_network=self.old_actor,
|
||||
projection=self.projection,
|
||||
**kwargs.get("loss", {})
|
||||
)
|
||||
|
||||
self.optimizers = nn.ModuleDict({
|
||||
"policy": torch.optim.Adam(self.actor.parameters(), lr=kwargs.get("lr_policy", 3e-4)),
|
||||
"critic": torch.optim.Adam(self.critic.parameters(), lr=kwargs.get("lr_critic", 3e-4))
|
||||
})
|
||||
|
||||
self.update_policy_interval = kwargs.get("update_policy_interval", 1)
|
||||
self.update_critic_interval = kwargs.get("update_critic_interval", 1)
|
||||
self.target_update_interval = kwargs.get("target_update_interval", 1)
|
||||
self.polyak_weight_critic = kwargs.get("polyak_weight_critic", 0.995)
|
||||
|
||||
def train(self, num_iterations: int = 1000) -> None:
|
||||
for i in range(num_iterations):
|
||||
data = next(self.collector)
|
||||
self.replay_buffer.extend(data)
|
||||
|
||||
batch = self.replay_buffer.sample().to(self.device)
|
||||
loss_dict = self.loss_module(batch)
|
||||
|
||||
if i % self.update_policy_interval == 0:
|
||||
self.optimizers["policy"].zero_grad()
|
||||
loss_dict["policy_loss"].backward()
|
||||
self.optimizers["policy"].step()
|
||||
|
||||
if i % self.update_critic_interval == 0:
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_dict["critic_loss"].backward()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if i % self.target_update_interval == 0:
|
||||
self.critic.update_target_params(self.polyak_weight_critic)
|
||||
|
||||
self.old_actor.load_state_dict(self.actor.state_dict())
|
||||
self.collector.update_policy_weights_()
|
||||
|
||||
if i % 100 == 0:
|
||||
eval_reward = self.eval()
|
||||
print(f"Iteration {i}, Eval reward: {eval_reward}")
|
||||
|
||||
def eval(self, num_episodes: int = 10) -> float:
|
||||
total_reward = 0
|
||||
for _ in range(num_episodes):
|
||||
td = self.env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
with torch.no_grad():
|
||||
action = self.actor(td.to(self.device))["action"]
|
||||
td = self.env.step(action)
|
||||
total_reward += td["reward"].item()
|
||||
done = td["done"].item()
|
||||
return total_reward / num_episodes
|
||||
|
||||
def save_policy(self, path: str) -> None:
|
||||
torch.save(self.actor.state_dict(), f"{path}/actor.pth")
|
||||
torch.save(self.critic.state_dict(), f"{path}/critic.pth")
|
||||
|
||||
def load_policy(self, path: str) -> None:
|
||||
self.actor.load_state_dict(torch.load(f"{path}/actor.pth"))
|
||||
self.critic.load_state_dict(torch.load(f"{path}/critic.pth"))
|
||||
self.old_actor.load_state_dict(self.actor.state_dict())
|
@@ -38,40 +38,100 @@ from torchrl.objectives.value import (
|
||||
)
|
||||
|
||||
from torchrl.objectives.ppo import PPOLoss
|
||||
from fancy_rl.projections import get_projection
|
||||
|
||||
class TRPLLoss(PPOLoss):
|
||||
@dataclass
|
||||
class _AcceptedKeys:
|
||||
"""Maintains default values for all configurable tensordict keys.
|
||||
|
||||
This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
|
||||
default values
|
||||
|
||||
Attributes:
|
||||
advantage (NestedKey): The input tensordict key where the advantage is expected.
|
||||
Will be used for the underlying value estimator. Defaults to ``"advantage"``.
|
||||
value_target (NestedKey): The input tensordict key where the target state value is expected.
|
||||
Will be used for the underlying value estimator Defaults to ``"value_target"``.
|
||||
value (NestedKey): The input tensordict key where the state value is expected.
|
||||
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
|
||||
sample_log_prob (NestedKey): The input tensordict key where the
|
||||
sample log probability is expected. Defaults to ``"sample_log_prob"``.
|
||||
action (NestedKey): The input tensordict key where the action is expected.
|
||||
Defaults to ``"action"``.
|
||||
reward (NestedKey): The input tensordict key where the reward is expected.
|
||||
Will be used for the underlying value estimator. Defaults to ``"reward"``.
|
||||
done (NestedKey): The key in the input TensorDict that indicates
|
||||
whether a trajectory is done. Will be used for the underlying value estimator.
|
||||
Defaults to ``"done"``.
|
||||
terminated (NestedKey): The key in the input TensorDict that indicates
|
||||
whether a trajectory is terminated. Will be used for the underlying value estimator.
|
||||
Defaults to ``"terminated"``.
|
||||
"""
|
||||
|
||||
advantage: NestedKey = "advantage"
|
||||
value_target: NestedKey = "value_target"
|
||||
value: NestedKey = "state_value"
|
||||
sample_log_prob: NestedKey = "sample_log_prob"
|
||||
action: NestedKey = "action"
|
||||
reward: NestedKey = "reward"
|
||||
done: NestedKey = "done"
|
||||
terminated: NestedKey = "terminated"
|
||||
|
||||
default_keys = _AcceptedKeys()
|
||||
default_value_estimator = ValueEstimators.GAE
|
||||
|
||||
|
||||
actor_network: TensorDictModule
|
||||
critic_network: TensorDictModule
|
||||
actor_network_params: TensorDictParams
|
||||
critic_network_params: TensorDictParams
|
||||
target_actor_network_params: TensorDictParams
|
||||
target_critic_network_params: TensorDictParams
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor_network: ProbabilisticTensorDictSequential,
|
||||
old_actor_network: ProbabilisticTensorDictSequential,
|
||||
critic_network: TensorDictModule,
|
||||
projection: any,
|
||||
actor_network: ProbabilisticTensorDictSequential | None = None,
|
||||
critic_network: TensorDictModule | None = None,
|
||||
trust_region_layer: any | None = None,
|
||||
entropy_bonus: bool = True,
|
||||
samples_mc_entropy: int = 1,
|
||||
entropy_coef: float = 0.01,
|
||||
critic_coef: float = 1.0,
|
||||
trust_region_coef: float = 10.0,
|
||||
loss_critic_type: str = "smooth_l1",
|
||||
normalize_advantage: bool = False,
|
||||
gamma: float = None,
|
||||
separate_losses: bool = False,
|
||||
reduction: str = None,
|
||||
clip_value: bool | float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
actor_network=actor_network,
|
||||
critic_network=critic_network,
|
||||
self.trust_region_layer = trust_region_layer
|
||||
self.trust_region_coef = trust_region_coef
|
||||
|
||||
super(TRPLLoss, self).__init__(
|
||||
actor_network,
|
||||
critic_network,
|
||||
entropy_bonus=entropy_bonus,
|
||||
samples_mc_entropy=samples_mc_entropy,
|
||||
entropy_coef=entropy_coef,
|
||||
critic_coef=critic_coef,
|
||||
loss_critic_type=loss_critic_type,
|
||||
normalize_advantage=normalize_advantage,
|
||||
gamma=gamma,
|
||||
separate_losses=separate_losses,
|
||||
reduction=reduction,
|
||||
clip_value=clip_value,
|
||||
**kwargs,
|
||||
)
|
||||
self.old_actor_network = old_actor_network
|
||||
self.projection = projection
|
||||
self.trust_region_coef = trust_region_coef
|
||||
|
||||
@property
|
||||
def out_keys(self):
|
||||
if self._out_keys is None:
|
||||
keys = ["loss_objective", "tr_loss"]
|
||||
keys = ["loss_objective"]
|
||||
if self.entropy_bonus:
|
||||
keys.extend(["entropy", "loss_entropy"])
|
||||
if self.critic_coef:
|
||||
if self.loss_critic:
|
||||
keys.append("loss_critic")
|
||||
keys.append("ESS")
|
||||
self._out_keys = keys
|
||||
@@ -81,12 +141,8 @@ class TRPLLoss(PPOLoss):
|
||||
def out_keys(self, values):
|
||||
self._out_keys = values
|
||||
|
||||
def _trust_region_loss(self, tensordict):
|
||||
old_distribution = self.old_actor_network(tensordict)
|
||||
raw_distribution = self.actor_network(tensordict)
|
||||
return self.projection(self.actor_network, raw_distribution, old_distribution)
|
||||
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDict:
|
||||
@dispatch
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
tensordict = tensordict.clone(False)
|
||||
advantage = tensordict.get(self.tensor_keys.advantage, None)
|
||||
if advantage is None:
|
||||
@@ -103,8 +159,11 @@ class TRPLLoss(PPOLoss):
|
||||
|
||||
log_weight, dist, kl_approx = self._log_weight(tensordict)
|
||||
trust_region_loss_unscaled = self._trust_region_loss(tensordict)
|
||||
|
||||
# ESS for logging
|
||||
with torch.no_grad():
|
||||
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
|
||||
# to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
|
||||
# of the weights.
|
||||
lw = log_weight.squeeze()
|
||||
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
|
||||
batch = log_weight.shape[0]
|
||||
@@ -135,3 +194,8 @@ class TRPLLoss(PPOLoss):
|
||||
batch_size=[],
|
||||
)
|
||||
return td_out
|
||||
|
||||
def _trust_region_loss(self, tensordict):
|
||||
old_distribution =
|
||||
raw_distribution =
|
||||
return self.policy_projection.get_trust_region_loss(raw_distribution, old_distribution)
|
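The _AcceptedKeys docstring above notes that the tensordict keys of TRPLLoss can be remapped with .set_keys(key_name=key_value), the standard torchrl LossModule mechanism. A hedged sketch (loss_module stands in for a constructed TRPLLoss instance; the names on the right are hypothetical):

def remap_keys(loss_module):
    # Remap the documented default keys to custom collector output keys.
    loss_module.set_keys(
        advantage="gae_advantage",
        value="state_value",
        sample_log_prob="action_log_prob",
    )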
@@ -1,106 +0,0 @@
|
||||
import torch
|
||||
from torchrl.objectives import LossModule
|
||||
from torch.distributions import Normal
|
||||
|
||||
class VLEARNLoss(LossModule):
|
||||
def __init__(
|
||||
self,
|
||||
actor_network,
|
||||
critic_network,
|
||||
old_actor_network,
|
||||
gamma=0.99,
|
||||
lmbda=0.95,
|
||||
entropy_coef=0.01,
|
||||
critic_coef=0.5,
|
||||
normalize_advantage=True,
|
||||
eps=1e-8,
|
||||
delta=0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.actor_network = actor_network
|
||||
self.critic_network = critic_network
|
||||
self.old_actor_network = old_actor_network
|
||||
self.gamma = gamma
|
||||
self.lmbda = lmbda
|
||||
self.entropy_coef = entropy_coef
|
||||
self.critic_coef = critic_coef
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.eps = eps
|
||||
self.delta = delta
|
||||
|
||||
def forward(self, tensordict):
|
||||
# Compute returns and advantages
|
||||
with torch.no_grad():
|
||||
returns = self.compute_returns(tensordict)
|
||||
values = self.critic_network(tensordict)["state_value"]
|
||||
advantages = returns - values
|
||||
if self.normalize_advantage:
|
||||
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
||||
|
||||
# Compute actor loss
|
||||
new_td = self.actor_network(tensordict)
|
||||
old_td = self.old_actor_network(tensordict)
|
||||
|
||||
new_dist = Normal(new_td["loc"], new_td["scale"])
|
||||
old_dist = Normal(old_td["loc"], old_td["scale"])
|
||||
|
||||
new_log_prob = new_dist.log_prob(tensordict["action"]).sum(-1)
|
||||
old_log_prob = old_dist.log_prob(tensordict["action"]).sum(-1)
|
||||
|
||||
ratio = torch.exp(new_log_prob - old_log_prob)
|
||||
|
||||
# Compute projection
|
||||
kl = torch.distributions.kl.kl_divergence(new_dist, old_dist).sum(-1)
|
||||
alpha = torch.where(kl > self.delta,
|
||||
torch.sqrt(self.delta / (kl + self.eps)),
|
||||
torch.ones_like(kl))
|
||||
proj_loc = alpha.unsqueeze(-1) * new_td["loc"] + (1 - alpha.unsqueeze(-1)) * old_td["loc"]
|
||||
proj_scale = torch.sqrt(alpha.unsqueeze(-1)**2 * new_td["scale"]**2 + (1 - alpha.unsqueeze(-1))**2 * old_td["scale"]**2)
|
||||
proj_dist = Normal(proj_loc, proj_scale)
|
||||
|
||||
proj_log_prob = proj_dist.log_prob(tensordict["action"]).sum(-1)
|
||||
proj_ratio = torch.exp(proj_log_prob - old_log_prob)
|
||||
|
||||
policy_loss = -torch.min(
|
||||
ratio * advantages,
|
||||
proj_ratio * advantages
|
||||
).mean()
|
||||
|
||||
# Compute critic loss
|
||||
value_pred = self.critic_network(tensordict)["state_value"]
|
||||
critic_loss = 0.5 * (returns - value_pred).pow(2).mean()
|
||||
|
||||
# Compute entropy loss
|
||||
entropy_loss = -self.entropy_coef * new_dist.entropy().mean()
|
||||
|
||||
# Combine losses
|
||||
loss = policy_loss + self.critic_coef * critic_loss + entropy_loss
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
"policy_loss": policy_loss,
|
||||
"critic_loss": critic_loss,
|
||||
"entropy_loss": entropy_loss,
|
||||
}
|
||||
|
||||
def compute_returns(self, tensordict):
|
||||
rewards = tensordict["reward"]
|
||||
dones = tensordict["done"]
|
||||
values = self.critic_network(tensordict)["state_value"]
|
||||
|
||||
returns = torch.zeros_like(rewards)
|
||||
advantages = torch.zeros_like(rewards)
|
||||
last_gae_lam = 0
|
||||
|
||||
for t in reversed(range(len(rewards))):
|
||||
if t == len(rewards) - 1:
|
||||
next_value = 0
|
||||
else:
|
||||
next_value = values[t + 1]
|
||||
|
||||
delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
|
||||
advantages[t] = last_gae_lam = delta + self.gamma * self.lmbda * (1 - dones[t]) * last_gae_lam
|
||||
|
||||
returns = advantages + values
|
||||
|
||||
return returns
|
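compute_returns above implements the standard GAE recursion; in the notation of that loop (gamma = γ, lmbda = λ, done flag = d_t), a reference form is:

\delta_t = r_t + \gamma (1 - d_t)\, V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t)\, \hat{A}_{t+1}, \qquad
R_t = \hat{A}_t + V(s_t)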
@@ -1,20 +1,6 @@
|
||||
from .base_projection import BaseProjection
|
||||
from .identity_projection import IdentityProjection
|
||||
from .kl_projection import KLProjection
|
||||
from .wasserstein_projection import WassersteinProjection
|
||||
from .frobenius_projection import FrobeniusProjection
|
||||
|
||||
def get_projection(projection_name: str):
|
||||
projections = {
|
||||
"identity_projection": IdentityProjection,
|
||||
"kl_projection": KLProjection,
|
||||
"wasserstein_projection": WassersteinProjection,
|
||||
"frobenius_projection": FrobeniusProjection,
|
||||
}
|
||||
|
||||
projection = projections.get(projection_name.lower())
|
||||
if projection is None:
|
||||
raise ValueError(f"Unknown projection: {projection_name}")
|
||||
return projection
|
||||
|
||||
__all__ = ["BaseProjection", "IdentityProjection", "KLProjection", "WassersteinProjection", "FrobeniusProjection", "get_projection"]
|
||||
try:
|
||||
import cpp_projection
|
||||
except ModuleNotFoundError:
|
||||
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
|
||||
else:
|
||||
from .kl_projection_layer import KLProjectionLayer
|
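The removed registry above resolves lowercase names to projection classes and raises ValueError for unknown names; the replacement shown after it only selects the KLProjectionLayer backend. A hedged usage sketch against the removed interface:

# Assumes the removed version of fancy_rl/projections/__init__.py shown above,
# where get_projection is still exported.
from fancy_rl.projections import get_projection

kl_cls = get_projection("kl_projection")            # -> KLProjection
w2_cls = get_projection("Wasserstein_Projection")   # lookup lower-cases the name
try:
    get_projection("does_not_exist")
except ValueError as err:
    print(err)  # Unknown projection: does_not_exist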
@@ -1,16 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
class BaseProjection(ABC, torch.nn.Module):
|
||||
def __init__(self, in_keys: list[str], out_keys: list[str]):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.out_keys = out_keys
|
||||
|
||||
@abstractmethod
|
||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
pass
|
||||
|
||||
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
return self.project(policy_params, old_policy_params)
|
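The removed BaseProjection above is an abstract torch.nn.Module whose forward simply delegates to project, exchanging dictionaries of distribution parameters. A hedged sketch of a made-up subclass exercising that interface (NoOpProjection is not part of the repository):

from typing import Dict
import torch
from fancy_rl.projections.base_projection import BaseProjection  # removed module shown above

class NoOpProjection(BaseProjection):
    # Illustrative subclass only; it behaves like the IdentityProjection further below.
    def project(self, policy_params: Dict[str, torch.Tensor],
                old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return policy_params

proj = NoOpProjection(in_keys=["loc", "scale_tril"], out_keys=["loc", "scale_tril"])
new_params = {"loc": torch.zeros(2), "scale_tril": torch.eye(2)}
old_params = {"loc": torch.ones(2), "scale_tril": torch.eye(2)}
assert proj(new_params, old_params)["loc"].equal(new_params["loc"])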
@@ -1,67 +0,0 @@
|
||||
import torch
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict
|
||||
|
||||
class FrobeniusProjection(BaseProjection):
|
||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
|
||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
||||
self.scale_prec = scale_prec
|
||||
|
||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
mean, chol = policy_params["loc"], policy_params["scale_tril"]
|
||||
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
|
||||
|
||||
cov = torch.matmul(chol, chol.transpose(-1, -2))
|
||||
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
|
||||
|
||||
mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
|
||||
|
||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||
proj_cov = self._cov_projection(cov, old_cov, cov_part)
|
||||
|
||||
proj_chol = torch.linalg.cholesky(proj_cov)
|
||||
return {"loc": proj_mean, "scale_tril": proj_chol}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean, chol = policy_params["loc"], policy_params["scale_tril"]
|
||||
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
||||
|
||||
cov = torch.matmul(chol, chol.transpose(-1, -2))
|
||||
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
|
||||
|
||||
mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
|
||||
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
|
||||
|
||||
return (mean_diff + cov_diff).mean() * self.trust_region_coeff
|
||||
|
||||
def _gaussian_frobenius(self, p, q):
|
||||
mean, cov = p
|
||||
old_mean, old_cov = q
|
||||
|
||||
if self.scale_prec:
|
||||
prec_old = torch.inverse(old_cov)
|
||||
mean_part = torch.sum(torch.matmul(mean - old_mean, prec_old) * (mean - old_mean), dim=-1)
|
||||
cov_part = torch.sum(prec_old * cov, dim=(-2, -1)) - torch.logdet(torch.matmul(prec_old, cov)) - mean.shape[-1]
|
||||
else:
|
||||
mean_part = torch.sum(torch.square(mean - old_mean), dim=-1)
|
||||
cov_part = torch.sum(torch.square(cov - old_cov), dim=(-2, -1))
|
||||
|
||||
return mean_part, cov_part
|
||||
|
||||
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
||||
diff = mean - old_mean
|
||||
norm = torch.sqrt(mean_part)
|
||||
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
|
||||
|
||||
def _cov_projection(self, cov: torch.Tensor, old_cov: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
||||
batch_shape = cov.shape[:-2]
|
||||
cov_mask = cov_part > self.cov_bound
|
||||
|
||||
eta = torch.ones(batch_shape, dtype=cov.dtype, device=cov.device)
|
||||
eta[cov_mask] = torch.sqrt(cov_part[cov_mask] / self.cov_bound) - 1.
|
||||
eta = torch.max(-eta, eta)
|
||||
|
||||
new_cov = (cov + torch.einsum('i,ijk->ijk', eta, old_cov)) / (1. + eta + 1e-16)[..., None, None]
|
||||
proj_cov = torch.where(cov_mask[..., None, None], new_cov, cov)
|
||||
|
||||
return proj_cov
|
@@ -1,13 +0,0 @@
|
||||
import torch
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict
|
||||
|
||||
class IdentityProjection(BaseProjection):
|
||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
|
||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
||||
|
||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
return policy_params
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
return torch.tensor(0.0, device=next(iter(policy_params.values())).device)
|
@@ -1,199 +0,0 @@
|
||||
import torch
|
||||
import cpp_projection
|
||||
import numpy as np
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict, Tuple, Any
|
||||
|
||||
MAX_EVAL = 1000
|
||||
|
||||
def get_numpy(tensor):
|
||||
return tensor.detach().cpu().numpy()
|
||||
|
||||
class KLProjection(BaseProjection):
|
||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
|
||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
||||
self.is_diag = is_diag
|
||||
self.contextual_std = contextual_std
|
||||
|
||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
mean, std = policy_params["loc"], policy_params["scale_tril"]
|
||||
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
|
||||
|
||||
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
|
||||
|
||||
if not self.contextual_std:
|
||||
std = std[:1]
|
||||
old_std = old_std[:1]
|
||||
cov_part = cov_part[:1]
|
||||
|
||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||
proj_std = self._cov_projection(std, old_std, cov_part)
|
||||
|
||||
if not self.contextual_std:
|
||||
proj_std = proj_std.expand(mean.shape[0], -1, -1)
|
||||
|
||||
return {"loc": proj_mean, "scale_tril": proj_std}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean, std = policy_params["loc"], policy_params["scale_tril"]
|
||||
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
||||
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
|
||||
return kl.mean() * self.trust_region_coeff
|
||||
|
||||
def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
mean, std = p
|
||||
mean_other, std_other = q
|
||||
k = mean.shape[-1]
|
||||
|
||||
maha_part = 0.5 * self._maha(mean, mean_other, std_other)
|
||||
|
||||
det_term = self._log_determinant(std)
|
||||
det_term_other = self._log_determinant(std_other)
|
||||
|
||||
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
|
||||
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
|
||||
|
||||
return maha_part, cov_part
|
||||
|
||||
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
|
||||
diff = x - y
|
||||
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
|
||||
|
||||
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
|
||||
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
|
||||
|
||||
def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.sum(x.pow(2), dim=(-2, -1))
|
||||
|
||||
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
||||
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
|
||||
|
||||
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
||||
cov = torch.matmul(std, std.transpose(-1, -2))
|
||||
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
|
||||
|
||||
if self.is_diag:
|
||||
mask = cov_part > self.cov_bound
|
||||
proj_std = torch.zeros_like(std)
|
||||
proj_std[~mask] = std[~mask]
|
||||
try:
|
||||
if mask.any():
|
||||
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
|
||||
old_cov.diagonal(dim1=-2, dim2=-1),
|
||||
self.cov_bound)
|
||||
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
|
||||
if is_invalid.any():
|
||||
proj_std[is_invalid] = old_std[is_invalid]
|
||||
mask &= ~is_invalid
|
||||
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
|
||||
except Exception as e:
|
||||
proj_std = old_std
|
||||
else:
|
||||
try:
|
||||
mask = cov_part > self.cov_bound
|
||||
proj_std = torch.zeros_like(std)
|
||||
proj_std[~mask] = std[~mask]
|
||||
if mask.any():
|
||||
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
|
||||
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
|
||||
if is_invalid.any():
|
||||
proj_std[is_invalid] = old_std[is_invalid]
|
||||
mask &= ~is_invalid
|
||||
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
|
||||
failed_mask = failed_mask.bool()
|
||||
if failed_mask.any():
|
||||
proj_std[failed_mask] = old_std[failed_mask]
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.error('Projection failed, taking old cholesky for projection.')
|
||||
print("Projection failed, taking old cholesky for projection.")
|
||||
proj_std = old_std
|
||||
raise e
|
||||
|
||||
return proj_std
|
||||
|
||||
|
||||
class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
|
||||
projection_op = None
|
||||
|
||||
@staticmethod
|
||||
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
|
||||
if not KLProjectionGradFunctionCovOnly.projection_op:
|
||||
KLProjectionGradFunctionCovOnly.projection_op = \
|
||||
cpp_projection.BatchedCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
|
||||
return KLProjectionGradFunctionCovOnly.projection_op
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
cov, chol, old_chol, eps_cov = args
|
||||
|
||||
batch_shape = cov.shape[0]
|
||||
dim = cov.shape[-1]
|
||||
|
||||
cov_np = get_numpy(cov)
|
||||
chol_np = get_numpy(chol)
|
||||
old_chol_np = get_numpy(old_chol)
|
||||
eps = get_numpy(eps_cov) * np.ones(batch_shape)
|
||||
|
||||
p_op = KLProjectionGradFunctionCovOnly.get_projection_op(batch_shape, dim)
|
||||
ctx.proj = p_op
|
||||
|
||||
proj_std = p_op.forward(eps, old_chol_np, chol_np, cov_np)
|
||||
|
||||
return cov.new(proj_std)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
projection_op = ctx.proj
|
||||
d_cov, = grad_outputs
|
||||
|
||||
d_cov_np = get_numpy(d_cov)
|
||||
d_cov_np = np.atleast_2d(d_cov_np)
|
||||
|
||||
df_stds = projection_op.backward(d_cov_np)
|
||||
df_stds = np.atleast_2d(df_stds)
|
||||
|
||||
df_stds = d_cov.new(df_stds)
|
||||
|
||||
return df_stds, None, None, None
|
||||
|
||||
|
||||
class KLProjectionGradFunctionDiagCovOnly(torch.autograd.Function):
|
||||
projection_op = None
|
||||
|
||||
@staticmethod
|
||||
def get_projection_op(batch_shape, dim, max_eval=MAX_EVAL):
|
||||
if not KLProjectionGradFunctionDiagCovOnly.projection_op:
|
||||
KLProjectionGradFunctionDiagCovOnly.projection_op = \
|
||||
cpp_projection.BatchedDiagCovOnlyProjection(batch_shape, dim, max_eval=max_eval)
|
||||
return KLProjectionGradFunctionDiagCovOnly.projection_op
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
cov, old_cov, eps_cov = args
|
||||
|
||||
batch_shape = cov.shape[0]
|
||||
dim = cov.shape[-1]
|
||||
|
||||
cov_np = get_numpy(cov)
|
||||
old_cov_np = get_numpy(old_cov)
|
||||
eps = get_numpy(eps_cov) * np.ones(batch_shape)
|
||||
|
||||
p_op = KLProjectionGradFunctionDiagCovOnly.get_projection_op(batch_shape, dim)
|
||||
ctx.proj = p_op
|
||||
|
||||
proj_std = p_op.forward(eps, old_cov_np, cov_np)
|
||||
|
||||
return cov.new(proj_std)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
projection_op = ctx.proj
|
||||
d_std, = grad_outputs
|
||||
|
||||
d_cov_np = get_numpy(d_std)
|
||||
d_cov_np = np.atleast_2d(d_cov_np)
|
||||
df_stds = projection_op.backward(d_cov_np)
|
||||
df_stds = np.atleast_2d(df_stds)
|
||||
|
||||
return d_std.new(df_stds), None, None
|
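_gaussian_kl above splits the Gaussian KL divergence into a mean (Mahalanobis) part and a covariance part, computed from Cholesky factors. As a reference formula in standard notation (not quoted from the repository), with q = N(μ, Σ), p = N(μ_o, Σ_o) and k the dimension:

\mathrm{KL}(q \,\|\, p) =
\underbrace{\tfrac{1}{2} (\mu - \mu_o)^\top \Sigma_o^{-1} (\mu - \mu_o)}_{\text{maha\_part}}
+ \underbrace{\tfrac{1}{2}\big( \operatorname{tr}(\Sigma_o^{-1} \Sigma) - k + \log\det\Sigma_o - \log\det\Sigma \big)}_{\text{cov\_part}}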
@@ -1,56 +0,0 @@
|
||||
import torch
|
||||
from .base_projection import BaseProjection
|
||||
from typing import Dict, Tuple
|
||||
|
||||
def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
|
||||
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
mean, sqrt = p
|
||||
mean_other, sqrt_other = q
|
||||
|
||||
mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
|
||||
|
||||
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
|
||||
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
|
||||
|
||||
if scale_prec:
|
||||
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
|
||||
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
|
||||
c = sqrt_inv_other @ cov @ sqrt_inv_other
|
||||
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
|
||||
else:
|
||||
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
|
||||
|
||||
return mean_part, cov_part
|
||||
|
||||
class WassersteinProjection(BaseProjection):
|
||||
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
|
||||
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
|
||||
self.scale_prec = scale_prec
|
||||
|
||||
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
|
||||
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
|
||||
|
||||
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
|
||||
|
||||
proj_mean = self._mean_projection(mean, old_mean, mean_part)
|
||||
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
|
||||
|
||||
return {"loc": proj_mean, "scale_tril": proj_sqrt}
|
||||
|
||||
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
|
||||
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
|
||||
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
|
||||
w2 = mean_part + cov_part
|
||||
return w2.mean() * self.trust_region_coeff
|
||||
|
||||
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
|
||||
diff = mean - old_mean
|
||||
norm = torch.norm(diff, dim=-1, keepdim=True)
|
||||
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
|
||||
|
||||
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
|
||||
diff = sqrt - old_sqrt
|
||||
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
|
||||
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
|
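gaussian_wasserstein_commutative above (non-scale_prec branch) uses the commutative approximation of the squared 2-Wasserstein distance between Gaussians, treating sqrt as a matrix square root of the covariance. As a reference formula (an approximation, not quoted from the repository):

W_2^2\big(\mathcal{N}(\mu, \Sigma), \mathcal{N}(\mu_o, \Sigma_o)\big) \approx
\underbrace{\lVert \mu - \mu_o \rVert_2^2}_{\text{mean\_part}}
+ \underbrace{\operatorname{tr}\big(\Sigma + \Sigma_o - 2\, \Sigma_o^{1/2} \Sigma^{1/2}\big)}_{\text{cov\_part}}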
@@ -1,6 +0,0 @@
|
||||
try:
|
||||
import cpp_projection
|
||||
except ModuleNotFoundError:
|
||||
from .base_projection_layer import ITPALExceptionLayer as KLProjectionLayer
|
||||
else:
|
||||
from .kl_projection_layer import KLProjectionLayer
|
@@ -1,59 +1 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from fancy_rl import PPO
|
||||
import gymnasium as gym
|
||||
|
||||
@pytest.fixture
|
||||
def simple_env():
|
||||
return gym.make('CartPole-v1')
|
||||
|
||||
def test_ppo_instantiation():
|
||||
ppo = PPO("CartPole-v1")
|
||||
assert isinstance(ppo, PPO)
|
||||
|
||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n_epochs", [5, 10])
|
||||
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
||||
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
|
||||
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
|
||||
ppo = PPO(
|
||||
"CartPole-v1",
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
n_epochs=n_epochs,
|
||||
gamma=gamma,
|
||||
clip_range=clip_range
|
||||
)
|
||||
assert ppo.learning_rate == learning_rate
|
||||
assert ppo.n_steps == n_steps
|
||||
assert ppo.batch_size == batch_size
|
||||
assert ppo.n_epochs == n_epochs
|
||||
assert ppo.gamma == gamma
|
||||
assert ppo.clip_range == clip_range
|
||||
|
||||
def test_ppo_predict(simple_env):
|
||||
ppo = PPO("CartPole-v1")
|
||||
obs, _ = simple_env.reset()
|
||||
action, _ = ppo.predict(obs)
|
||||
assert isinstance(action, np.ndarray)
|
||||
assert action.shape == simple_env.action_space.shape
|
||||
|
||||
def test_ppo_learn():
|
||||
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
|
||||
env = gym.make("CartPole-v1")
|
||||
obs, _ = env.reset()
|
||||
for _ in range(64):
|
||||
action, _ = ppo.predict(obs)
|
||||
next_obs, reward, done, truncated, _ = env.step(action)
|
||||
ppo.store_transition(obs, action, reward, done, next_obs)
|
||||
obs = next_obs
|
||||
if done or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
loss = ppo.learn()
|
||||
assert isinstance(loss, dict)
|
||||
assert "policy_loss" in loss
|
||||
assert "value_loss" in loss
|
||||
# TODO
|
@@ -1,77 +0,0 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from fancy_rl import TRPL
|
||||
import gymnasium as gym
|
||||
|
||||
@pytest.fixture
|
||||
def simple_env():
|
||||
return gym.make('CartPole-v1')
|
||||
|
||||
def test_trpl_instantiation():
|
||||
trpl = TRPL("CartPole-v1")
|
||||
assert isinstance(trpl, TRPL)
|
||||
|
||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
||||
@pytest.mark.parametrize("max_kl", [0.01, 0.05])
|
||||
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, max_kl):
|
||||
trpl = TRPL(
|
||||
"CartPole-v1",
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
gamma=gamma,
|
||||
max_kl=max_kl
|
||||
)
|
||||
assert trpl.learning_rate == learning_rate
|
||||
assert trpl.n_steps == n_steps
|
||||
assert trpl.batch_size == batch_size
|
||||
assert trpl.gamma == gamma
|
||||
assert trpl.max_kl == max_kl
|
||||
|
||||
def test_trpl_predict(simple_env):
|
||||
trpl = TRPL("CartPole-v1")
|
||||
obs, _ = simple_env.reset()
|
||||
action, _ = trpl.predict(obs)
|
||||
assert isinstance(action, np.ndarray)
|
||||
assert action.shape == simple_env.action_space.shape
|
||||
|
||||
def test_trpl_learn():
|
||||
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
|
||||
env = gym.make("CartPole-v1")
|
||||
obs, _ = env.reset()
|
||||
for _ in range(64):
|
||||
action, _ = trpl.predict(obs)
|
||||
next_obs, reward, done, truncated, _ = env.step(action)
|
||||
trpl.store_transition(obs, action, reward, done, next_obs)
|
||||
obs = next_obs
|
||||
if done or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
loss = trpl.learn()
|
||||
assert isinstance(loss, dict)
|
||||
assert "policy_loss" in loss
|
||||
assert "value_loss" in loss
|
||||
|
||||
def test_trpl_training(simple_env):
|
||||
trpl = TRPL("CartPole-v1", total_timesteps=10000)
|
||||
|
||||
initial_performance = evaluate_policy(trpl, simple_env)
|
||||
trpl.train()
|
||||
final_performance = evaluate_policy(trpl, simple_env)
|
||||
|
||||
assert final_performance > initial_performance, "TRPL should improve performance after training"
|
||||
|
||||
def evaluate_policy(policy, env, n_eval_episodes=10):
|
||||
total_reward = 0
|
||||
for _ in range(n_eval_episodes):
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
action, _ = policy.predict(obs)
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
total_reward += reward
|
||||
done = terminated or truncated
|
||||
return total_reward / n_eval_episodes
|
@@ -1,81 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
from fancy_rl import VLEARN
|
||||
import gymnasium as gym
|
||||
|
||||
@pytest.fixture
|
||||
def simple_env():
|
||||
return gym.make('CartPole-v1')
|
||||
|
||||
def test_vlearn_instantiation():
|
||||
vlearn = VLEARN("CartPole-v1")
|
||||
assert isinstance(vlearn, VLEARN)
|
||||
|
||||
@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
|
||||
@pytest.mark.parametrize("n_steps", [1024, 2048])
|
||||
@pytest.mark.parametrize("batch_size", [32, 64, 128])
|
||||
@pytest.mark.parametrize("gamma", [0.95, 0.99])
|
||||
@pytest.mark.parametrize("mean_bound", [0.05, 0.1])
|
||||
@pytest.mark.parametrize("cov_bound", [0.0005, 0.001])
|
||||
def test_vlearn_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, mean_bound, cov_bound):
|
||||
vlearn = VLEARN(
|
||||
"CartPole-v1",
|
||||
learning_rate=learning_rate,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size,
|
||||
gamma=gamma,
|
||||
mean_bound=mean_bound,
|
||||
cov_bound=cov_bound
|
||||
)
|
||||
assert vlearn.learning_rate == learning_rate
|
||||
assert vlearn.n_steps == n_steps
|
||||
assert vlearn.batch_size == batch_size
|
||||
assert vlearn.gamma == gamma
|
||||
assert vlearn.mean_bound == mean_bound
|
||||
assert vlearn.cov_bound == cov_bound
|
||||
|
||||
def test_vlearn_predict(simple_env):
|
||||
vlearn = VLEARN("CartPole-v1")
|
||||
obs, _ = simple_env.reset()
|
||||
action, _ = vlearn.predict(obs)
|
||||
assert isinstance(action, np.ndarray)
|
||||
assert action.shape == simple_env.action_space.shape
|
||||
|
||||
def test_vlearn_learn():
|
||||
vlearn = VLEARN("CartPole-v1", n_steps=64, batch_size=32)
|
||||
env = gym.make("CartPole-v1")
|
||||
obs, _ = env.reset()
|
||||
for _ in range(64):
|
||||
action, _ = vlearn.predict(obs)
|
||||
next_obs, reward, done, truncated, _ = env.step(action)
|
||||
vlearn.store_transition(obs, action, reward, done, next_obs)
|
||||
obs = next_obs
|
||||
if done or truncated:
|
||||
obs, _ = env.reset()
|
||||
|
||||
loss = vlearn.learn()
|
||||
assert isinstance(loss, dict)
|
||||
assert "policy_loss" in loss
|
||||
assert "value_loss" in loss
|
||||
|
||||
def test_vlearn_training(simple_env):
|
||||
vlearn = VLEARN("CartPole-v1", total_timesteps=10000)
|
||||
|
||||
initial_performance = evaluate_policy(vlearn, simple_env)
|
||||
vlearn.train()
|
||||
final_performance = evaluate_policy(vlearn, simple_env)
|
||||
|
||||
assert final_performance > initial_performance, "VLearn should improve performance after training"
|
||||
|
||||
def evaluate_policy(policy, env, n_eval_episodes=10):
|
||||
total_reward = 0
|
||||
for _ in range(n_eval_episodes):
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
action, _ = policy.predict(obs)
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
total_reward += reward
|
||||
done = terminated or truncated
|
||||
return total_reward / n_eval_episodes
|