Compare commits
8 commits: 906240e145 ... df1ba6fe53

Commits (SHA1):
- df1ba6fe53
- 8eb9b384c7
- abc8dcbda1
- e927afcc30
- ca1ee980ef
- 0c6e58634f
- 651ef1522f
- 71cb8593d9

Changed files shown below include README.md (23 lines changed).
@@ -52,17 +52,34 @@ To run the test suite:
 pytest test/test_ppo.py
 ```
 
-## TODO
+## Status
 
+### Implemented Features
+
+- Proximal Policy Optimization (PPO) algorithm
+- Trust Region Policy Layers (TRPL) algorithm (WIP)
+- Support for continuous and discrete action spaces
+- Multiple projection methods (Rewritten for MIT License Compatibility):
+  - KL Divergence projection
+  - Frobenius norm projection
+  - Wasserstein distance projection
+  - Identity projection (equivalent to PPO)
+- Configurable neural network architectures for actor and critic
+- Logging support (Terminal and WandB, extendable)
+
+### TODO
+
 - [ ] All PPO Tests green
 - [ ] Better / more logging
 - [ ] Test / Benchmark PPO
 - [ ] Refactor Modules for TRPL
 - [ ] Get TRPL working
-- [ ] Test / Benchmark TRPL
+- [ ] All TRPL Tests green
 - [ ] Make contextual covariance optional
 - [ ] Allow full-cov via chol
+- [ ] Test / Benchmark TRPL
 - [ ] Write docs / extend README
-- [ ] (Implement SAC?)
+- [ ] Test func of non-gym envs
+- [ ] Implement SAC
+- [ ] Implement VLEARN
 
 ## Contributing
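For orientation, a minimal usage sketch that is consistent with the API exercised by the updated tests further below (constructing an agent from an env id, training, and predicting). This is an illustration, not documented behaviour of the package:

```python
# Sketch only: PPO, make_env, train, and predict are the calls used in the diffed tests.
from fancy_rl import PPO

ppo = PPO("CartPole-v1")      # env_spec can be a gym id string or a callable returning an env
ppo.train()

env = ppo.make_env()
obs, _ = env.reset()
action, _ = ppo.predict(obs)  # returns a NumPy action and a None state
```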
@@ -3,6 +3,7 @@ import gymnasium as gym
 from torchrl.envs.libs.gym import GymWrapper
 from torchrl.record import VideoRecorder
 from abc import ABC
+from tensordict import TensorDict
 
 from fancy_rl.loggers import TerminalLogger
 

@@ -53,12 +54,11 @@ class Algo(ABC):
             env = GymWrapper(env).to(self.device)
         elif callable(env_spec):
             env = env_spec()
-            if isinstance(env, gym.Env):
-                env = GymWrapper(env).to(self.device)
-        elif isinstance(env, gym.Env):
+            if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
+                raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
             env = GymWrapper(env).to(self.device)
         else:
-            raise ValueError("env_spec must be a string or a callable that returns an environment.")
+            raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
         return env
 
     def train_step(self, batch):

@@ -70,6 +70,20 @@ class Algo(ABC):
     def evaluate(self, epoch):
         raise NotImplementedError("evaluate method must be implemented in subclass.")
 
-def dump_video(module):
-    if isinstance(module, VideoRecorder):
-        module.dump()
+    def predict(
+        self,
+        observation,
+        state=None,
+        deterministic=False
+    ):
+        with torch.no_grad():
+            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
+            td = TensorDict({"observation": obs_tensor}, batch_size=[1])
+
+            action_td = self.prob_actor(td)
+            action = action_td["action"]
+
+            # We're not using recurrent policies, so we'll always return None for the state
+            next_state = None
+
+            return action.squeeze(0).cpu().numpy(), next_state
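The new `Algo.predict` wraps a single observation into a `TensorDict`, runs the probabilistic actor, and returns a NumPy action together with a `None` recurrent state. A small call-pattern sketch (`agent` is a placeholder for any subclass instance, e.g. the PPO object above):

```python
# Sketch: Gymnasium-style rollout step using the new predict API; state is always None here.
obs, _ = env.reset()
action, state = agent.predict(obs, deterministic=False)
obs, reward, terminated, truncated, _ = env.step(action)
```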
@@ -55,7 +55,7 @@ class OnPolicy(Algo):
         # Create collector
         self.collector = SyncDataCollector(
             create_env_fn=lambda: self.make_env(eval=False),
-            policy=self.actor,
+            policy=self.prob_actor,
             frames_per_batch=self.n_steps,
             total_frames=self.total_timesteps,
             device=self.device,
@@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
 from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
-from fancy_rl.projections import get_projection # Updated import
+from fancy_rl.utils import is_discrete_space
 
 class PPO(OnPolicy):
     def __init__(

@@ -31,7 +31,10 @@ class PPO(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
+        self.clip_range = clip_range
+
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device

@@ -41,15 +44,29 @@ class PPO(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
+        self.discrete = is_discrete_space(act_space)
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["loc", "scale"],
-            out_keys=["action"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
-        )
+        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+
+        if self.discrete:
+            distribution_class = torch.distributions.Categorical
+            distribution_kwargs = {"logits": "action_logits"}
+        else:
+            if full_covariance:
+                distribution_class = torch.distributions.MultivariateNormal
+                in_keys = ["loc", "scale_tril"]
+            else:
+                distribution_class = torch.distributions.Normal
+                in_keys = ["loc", "scale"]
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=in_keys,
+            out_keys=["action"]
+        )
 
         optimizers = {
             "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),

@@ -86,7 +103,7 @@ class PPO(OnPolicy):
         self.loss_module = ClipPPOLoss(
             actor_network=self.actor,
             critic_network=self.critic,
-            clip_epsilon=clip_range,
+            clip_epsilon=self.clip_range,
             loss_critic_type='l2',
             entropy_coef=self.entropy_coef,
             critic_coef=self.critic_coef,
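With the new `full_covariance` flag, PPO's continuous policy switches from an elementwise `Normal` over `["loc", "scale"]` to a `MultivariateNormal` parameterized by `["loc", "scale_tril"]` (discrete spaces keep a `Categorical` over `action_logits`). A hedged construction sketch, using a continuous env id that appears in the tests; other keyword defaults are assumed unchanged:

```python
# Sketch: full_covariance is the constructor flag added in this diff.
ppo_diag = PPO("MountainCarContinuous-v0", full_covariance=False)  # diagonal Gaussian policy
ppo_full = PPO("MountainCarContinuous-v0", full_covariance=True)   # full-covariance Gaussian policy
```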
@@ -10,7 +10,36 @@ from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic
 from fancy_rl.projections import get_projection, BaseProjection
 from fancy_rl.objectives import TRPLLoss
+from fancy_rl.utils import is_discrete_space
 from copy import deepcopy
+from tensordict.nn import TensorDictModule
+from tensordict import TensorDict
+
+class ProjectedActor(TensorDictModule):
+    def __init__(self, raw_actor, old_actor, projection):
+        combined_module = self.CombinedModule(raw_actor, old_actor, projection)
+        super().__init__(
+            module=combined_module,
+            in_keys=raw_actor.in_keys,
+            out_keys=raw_actor.out_keys
+        )
+        self.raw_actor = raw_actor
+        self.old_actor = old_actor
+        self.projection = projection
+
+    class CombinedModule(nn.Module):
+        def __init__(self, raw_actor, old_actor, projection):
+            super().__init__()
+            self.raw_actor = raw_actor
+            self.old_actor = old_actor
+            self.projection = projection
+
+        def forward(self, tensordict):
+            raw_params = self.raw_actor(tensordict)
+            old_params = self.old_actor(tensordict)
+            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            projected_params = self.projection(combined_params)
+            return projected_params
 
 class TRPL(OnPolicy):
     def __init__(

@@ -40,6 +69,7 @@ class TRPL(OnPolicy):
         device=None,
         env_spec_eval=None,
         eval_episodes=10,
+        full_covariance=False,
     ):
         device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.device = device

@@ -50,8 +80,11 @@ class TRPL(OnPolicy):
         obs_space = env.observation_space
         act_space = env.action_space
 
+        assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"
+
         self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
+        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
 
         # Handle projection_class
         if isinstance(projection_class, str):

@@ -60,20 +93,27 @@ class TRPL(OnPolicy):
             raise ValueError("projection_class must be a string or a subclass of BaseProjection")
 
         self.projection = projection_class(
-            in_keys=["loc", "scale"],
-            out_keys=["loc", "scale"],
-            trust_region_bound_mean=trust_region_bound_mean,
-            trust_region_bound_cov=trust_region_bound_cov
+            in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
+            out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
+            mean_bound=trust_region_bound_mean,
+            cov_bound=trust_region_bound_cov
         )
 
-        self.actor = ProbabilisticActor(
-            module=actor_net,
-            in_keys=["observation"],
-            out_keys=["loc", "scale"],
-            distribution_class=torch.distributions.Normal,
-            return_log_prob=True
+        self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)
+
+        if full_covariance:
+            distribution_class = torch.distributions.MultivariateNormal
+            distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
+        else:
+            distribution_class = torch.distributions.Normal
+            distribution_kwargs = {"loc": "loc", "scale": "scale"}
+
+        self.prob_actor = ProbabilisticActor(
+            module=self.actor,
+            distribution_class=distribution_class,
+            return_log_prob=True,
+            in_keys=distribution_kwargs,
         )
-        self.old_actor = deepcopy(self.actor)
 
         self.trust_region_coef = trust_region_coef
         self.loss_module = TRPLLoss(

@@ -88,7 +128,7 @@ class TRPL(OnPolicy):
         )
 
         optimizers = {
-            "actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
+            "actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
             "critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
         }
 

@@ -119,23 +159,7 @@ class TRPL(OnPolicy):
         )
 
     def update_old_policy(self):
-        self.old_actor.load_state_dict(self.actor.state_dict())
+        self.old_actor.load_state_dict(self.raw_actor.state_dict())
 
-    def project_policy(self, obs):
-        with torch.no_grad():
-            old_dist = self.old_actor(obs)
-            new_dist = self.actor(obs)
-            projected_params = self.projection.project(new_dist, old_dist)
-        return projected_params
-
-    def pre_update(self, tensordict):
-        obs = tensordict["observation"]
-        projected_dist = self.project_policy(obs)
-
-        # Update tensordict with projected distribution parameters
-        tensordict["projected_loc"] = projected_dist[0]
-        tensordict["projected_scale"] = projected_dist[1]
-        return tensordict
-
     def post_update(self):
         self.update_old_policy()
@@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):
 
     def _trust_region_loss(self, tensordict):
         old_distribution = self.old_actor_network(tensordict)
-        raw_distribution = self.actor_network(tensordict)
-        return self.projection(self.actor_network, raw_distribution, old_distribution)
+        new_distribution = self.actor_network(tensordict)
+        return self.projection.get_trust_region_loss(new_distribution, old_distribution)
 
     def forward(self, tensordict: TensorDictBase) -> TensorDict:
         tensordict = tensordict.clone(False)
@@ -3,31 +3,63 @@ from tensordict.nn import TensorDictModule
 from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
+from tensordict import TensorDict
 
 class Actor(TensorDictModule):
-    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
+        self.discrete = is_discrete_space(act_space)
         act_space_shape = get_space_shape(act_space)
-        if is_discrete_space(act_space):
-            out_features = act_space_shape[-1]
-        else:
-            out_features = act_space_shape[-1] * 2
 
-        actor_module = nn.Sequential(
-            MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=out_features,
-                num_cells=hidden_sizes,
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            ),
-            NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
+        if self.discrete and full_covariance:
+            raise ValueError("Full covariance is not applicable for discrete action spaces.")
+
+        self.full_covariance = full_covariance
+
+        if self.discrete:
+            out_features = act_space_shape[-1]
+            out_keys = ["action_logits"]
+        else:
+            if full_covariance:
+                out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
+                out_keys = ["loc", "scale_tril"]
+            else:
+                out_features = act_space_shape[-1] * 2
+                out_keys = ["loc", "scale"]
+
+        actor_module = MLP(
+            in_features=get_space_shape(obs_space)[-1],
+            out_features=out_features,
+            num_cells=hidden_sizes,
+            activation_class=getattr(nn, activation_fn),
+            device=device
         ).to(device)
+
+        if not self.discrete:
+            if full_covariance:
+                param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
+            else:
+                param_extractor = NormalParamExtractor()
+            actor_module = nn.Sequential(actor_module, param_extractor)
+
         super().__init__(
             module=actor_module,
             in_keys=["observation"],
-            out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
+            out_keys=out_keys
         )
 
+class FullCovarianceNormalParamExtractor(nn.Module):
+    def __init__(self, action_dim):
+        super().__init__()
+        self.action_dim = action_dim
+
+    def forward(self, x):
+        loc = x[:, :self.action_dim]
+        scale_tril = torch.zeros(x.shape[0], self.action_dim, self.action_dim, device=x.device)
+        tril_indices = torch.tril_indices(row=self.action_dim, col=self.action_dim, offset=0)
+        scale_tril[:, tril_indices[0], tril_indices[1]] = x[:, self.action_dim:]
+        scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
+        return TensorDict({"loc": loc, "scale_tril": scale_tril}, batch_size=x.shape[0])
+
 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         critic_module = MLP(
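For an action dimension n, the full-covariance head therefore emits n + n(n+1)/2 values: n for `loc` and n(n+1)/2 for the lower-triangular `scale_tril` entries (e.g. n = 2 gives 2 + 3 = 5, n = 3 gives 3 + 6 = 9). A shape-only sketch of the extractor defined above on dummy data, not part of the diff:

```python
# Sketch: exercises FullCovarianceNormalParamExtractor from the diff with batch size 4 and action_dim 2.
import torch

action_dim = 2
extractor = FullCovarianceNormalParamExtractor(action_dim)

x = torch.randn(4, action_dim + action_dim * (action_dim + 1) // 2)  # shape [4, 5]
out = extractor(x)
out["loc"].shape         # torch.Size([4, 2])
out["scale_tril"].shape  # torch.Size([4, 2, 2]), lower triangular with exponentiated diagonal
```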
@@ -1,16 +1,71 @@
 from abc import ABC, abstractmethod
 import torch
-from typing import Dict
+from torch import nn
+from typing import Dict, List
 
-class BaseProjection(ABC, torch.nn.Module):
-    def __init__(self, in_keys: list[str], out_keys: list[str]):
+class BaseProjection(nn.Module, ABC):
+    def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
         super().__init__()
+        self._validate_in_keys(in_keys)
+        self._validate_out_keys(out_keys)
         self.in_keys = in_keys
         self.out_keys = out_keys
+        self.trust_region_coeff = trust_region_coeff
+        self.mean_bound = mean_bound
+        self.cov_bound = cov_bound
+        self.full_cov = "scale_tril" in in_keys
+        self.contextual_std = contextual_std
+
+    def _validate_in_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys or "old_loc" not in keys:
+            raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
+        if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
+            raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")
+
+    def _validate_out_keys(self, keys: List[str]):
+        valid_keys = {"loc", "scale", "scale_tril"}
+        if not set(keys).issubset(valid_keys):
+            raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
+        if "loc" not in keys:
+            raise ValueError("'loc' must be included in out_keys")
+        if "scale" not in keys and "scale_tril" not in keys:
+            raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")
 
     @abstractmethod
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         pass
 
-    def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        return self.project(policy_params, old_policy_params)
+    @abstractmethod
+    def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        pass
+
+    def forward(self, tensordict):
+        policy_params = {}
+        old_policy_params = {}
+
+        for key in self.in_keys:
+            if key not in tensordict:
+                raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")
+
+            if key.startswith("old_"):
+                old_policy_params[key[4:]] = tensordict[key]
+            else:
+                policy_params[key] = tensordict[key]
+
+        projected_params = self.project(policy_params, old_policy_params)
+        return projected_params
+
+    def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.diag_embed(params["scale"].pow(2))
+        else:
+            return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))
+
+    def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
+        if not self.full_cov:
+            return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
+        else:
+            return torch.linalg.cholesky(cov)
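The reworked `forward` now takes a single tensordict-like mapping, splits the `old_`-prefixed entries into the old-policy parameters, and passes both dicts to `project`. A minimal sketch with the identity projection and diagonal `scale` keys; the import path is assumed from the package imports seen elsewhere in this diff:

```python
# Sketch: IdentityProjection.project simply returns the current parameters.
import torch
from fancy_rl.projections import IdentityProjection  # assumed export

proj = IdentityProjection(
    in_keys=["loc", "scale", "old_loc", "old_scale"],
    out_keys=["loc", "scale"],
)

params = {
    "loc": torch.zeros(8, 2), "scale": torch.ones(8, 2),
    "old_loc": torch.zeros(8, 2), "old_scale": torch.ones(8, 2),
}
projected = proj(params)  # {"loc": ..., "scale": ...}, unchanged for the identity projection
```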
@@ -1,33 +1,34 @@
 import torch
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict
 
 class FrobeniusProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
 
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))
+        cov = self._calc_covariance(policy_params)
+        old_cov = self._calc_covariance(old_policy_params)
 
         mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
         proj_cov = self._cov_projection(cov, old_cov, cov_part)
 
-        proj_chol = torch.linalg.cholesky(proj_cov)
-        return {"loc": proj_mean, "scale_tril": proj_chol}
+        scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
+        return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, chol = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
 
-        cov = torch.matmul(chol, chol.transpose(-1, -2))
-        proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))
+        cov = self._calc_covariance(policy_params)
+        proj_cov = self._calc_covariance(proj_policy_params)
 
         mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
         cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
@@ -3,8 +3,8 @@ from .base_projection import BaseProjection
 from typing import Dict
 
 class IdentityProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         return policy_params
@@ -2,6 +2,7 @@ import torch
 import cpp_projection
 import numpy as np
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple, Any
 
 MAX_EVAL = 1000

@@ -10,57 +11,65 @@ def get_numpy(tensor):
     return tensor.detach().cpu().numpy()
 
 class KLProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
-        self.is_diag = is_diag
-        self.contextual_std = contextual_std
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
 
-        mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))
+        mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
 
         if not self.contextual_std:
-            std = std[:1]
-            old_std = old_std[:1]
+            scale_or_tril = scale_or_tril[:1]
+            old_scale_or_tril = old_scale_or_tril[:1]
             cov_part = cov_part[:1]
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_std = self._cov_projection(std, old_std, cov_part)
+        proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
 
         if not self.contextual_std:
-            proj_std = proj_std.expand(mean.shape[0], -1, -1)
+            proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
 
-        return {"loc": proj_mean, "scale_tril": proj_std}
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, std = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
+        mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
+        proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
+        kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
         return kl.mean() * self.trust_region_coeff
 
     def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        mean, std = p
-        mean_other, std_other = q
+        mean, scale_or_tril = p
+        mean_other, scale_or_tril_other = q
         k = mean.shape[-1]
 
-        maha_part = 0.5 * self._maha(mean, mean_other, std_other)
+        maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
 
-        det_term = self._log_determinant(std)
-        det_term_other = self._log_determinant(std_other)
+        det_term = self._log_determinant(scale_or_tril)
+        det_term_other = self._log_determinant(scale_or_tril_other)
 
-        trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
+        if self.full_cov:
+            trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
+        else:
+            trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
+
         cov_part = 0.5 * (trace_part - k + det_term_other - det_term)
 
         return maha_part, cov_part
 
-    def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
+    def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
         diff = x - y
-        return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)
+        if self.full_cov:
+            return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
+        else:
+            return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
 
-    def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
-        return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)
+    def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
+        else:
+            return 2 * torch.log(scale_or_tril).sum(-1)
 
     def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
         return torch.sum(x.pow(2), dim=(-2, -1))

@@ -68,49 +77,45 @@ class KLProjection(BaseProjection):
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)
 
-    def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        cov = torch.matmul(std, std.transpose(-1, -2))
-        old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))
-
-        if self.is_diag:
-            mask = cov_part > self.cov_bound
-            proj_std = torch.zeros_like(std)
-            proj_std[~mask] = std[~mask]
-            try:
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
-                                                                         old_cov.diagonal(dim1=-2, dim2=-1),
-                                                                         self.cov_bound)
-                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
-                    if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
-                        mask &= ~is_invalid
-                    proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
-            except Exception as e:
-                proj_std = old_std
-        else:
-            try:
-                mask = cov_part > self.cov_bound
-                proj_std = torch.zeros_like(std)
-                proj_std[~mask] = std[~mask]
-                if mask.any():
-                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
-                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
-                    if is_invalid.any():
-                        proj_std[is_invalid] = old_std[is_invalid]
-                        mask &= ~is_invalid
-                    proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
-                    failed_mask = failed_mask.bool()
-                    if failed_mask.any():
-                        proj_std[failed_mask] = old_std[failed_mask]
-            except Exception as e:
-                import logging
-                logging.error('Projection failed, taking old cholesky for projection.')
-                print("Projection failed, taking old cholesky for projection.")
-                proj_std = old_std
-                raise e
-
-        return proj_std
+    def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if self.full_cov:
+            cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
+            old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
+        else:
+            cov = scale_or_tril.pow(2)
+            old_cov = old_scale_or_tril.pow(2)
+
+        mask = cov_part > self.cov_bound
+        proj_scale_or_tril = torch.zeros_like(scale_or_tril)
+        proj_scale_or_tril[~mask] = scale_or_tril[~mask]
+
+        try:
+            if mask.any():
+                if self.full_cov:
+                    proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
+                    is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
+                    if is_invalid.any():
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
+                        mask &= ~is_invalid
+                    proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
+                    failed_mask = failed_mask.bool()
+                    if failed_mask.any():
+                        proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
+                else:
+                    proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
+                    is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
+                    if is_invalid.any():
+                        proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
+                        mask &= ~is_invalid
+                    proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
+        except Exception as e:
+            import logging
+            logging.error('Projection failed, taking old scale_or_tril for projection.')
+            print("Projection failed, taking old scale_or_tril for projection.")
+            proj_scale_or_tril = old_scale_or_tril
+            raise e
+
+        return proj_scale_or_tril
 
 
 class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
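For reference, the mean/covariance split computed in `_gaussian_kl` above matches the standard decomposition of the KL divergence between the current Gaussian and the other Gaussian, with k the action dimension:

$$
\mathrm{KL}\!\left(\mathcal{N}(\mu,\Sigma)\,\big\|\,\mathcal{N}(\mu_o,\Sigma_o)\right)
= \underbrace{\tfrac{1}{2}\,(\mu-\mu_o)^{\top}\Sigma_o^{-1}(\mu-\mu_o)}_{\text{mean part}}
+ \underbrace{\tfrac{1}{2}\!\left(\operatorname{tr}\!\big(\Sigma_o^{-1}\Sigma\big)-k+\log\tfrac{\det\Sigma_o}{\det\Sigma}\right)}_{\text{covariance part}}
$$

In the diagonal branch, Σ = diag(scale²), so the trace and log-determinant terms reduce to the elementwise sums computed in the `else` branches.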
@@ -1,56 +1,86 @@
 import torch
 from .base_projection import BaseProjection
+from tensordict.nn import TensorDictModule
 from typing import Dict, Tuple
 
+def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
+    """
+    'Converts' scale_tril to scale_sqrt.
+
+    For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
+    But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
+    """
+    return scale_tril
+
 def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
                                      q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
-    mean, sqrt = p
-    mean_other, sqrt_other = q
+    mean, scale_or_sqrt = p
+    mean_other, scale_or_sqrt_other = q
 
     mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)
 
-    cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
-    cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))
-
-    if scale_prec:
-        identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
-        sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
-        c = sqrt_inv_other @ cov @ sqrt_inv_other
-        cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
-    else:
-        cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)
+    if scale_or_sqrt.dim() == mean.dim():  # Diagonal case
+        cov = scale_or_sqrt.pow(2)
+        cov_other = scale_or_sqrt_other.pow(2)
+        if scale_prec:
+            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
+            sqrt_inv_other = 1 / scale_or_sqrt_other
+            c = sqrt_inv_other.pow(2) * cov
+            cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
+        else:
+            cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
+    else:  # Full covariance case
+        # Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
+        cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
+        cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
+        if scale_prec:
+            identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
+            sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
+            c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
+            cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
+        else:
+            cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
 
     return mean_part, cov_part
 
 class WassersteinProjection(BaseProjection):
-    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
-        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
+    def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
+        super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
         self.scale_prec = scale_prec
 
     def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]
+        mean = policy_params["loc"]
+        old_mean = old_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
 
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
 
         proj_mean = self._mean_projection(mean, old_mean, mean_part)
-        proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)
+        proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
 
-        return {"loc": proj_mean, "scale_tril": proj_sqrt}
+        return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
 
     def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
-        mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
-        proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
-        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
+        mean = policy_params["loc"]
+        proj_mean = proj_policy_params["loc"]
+        scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
+        proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
+        mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
         w2 = mean_part + cov_part
         return w2.mean() * self.trust_region_coeff
 
     def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
         diff = mean - old_mean
-        norm = torch.norm(diff, dim=-1, keepdim=True)
-        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)
+        norm = torch.sqrt(mean_part)
+        return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
 
-    def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
-        diff = sqrt - old_sqrt
-        norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
-        return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
+    def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
+        if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2:  # Diagonal case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.sqrt(cov_part)
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
+        else:  # Full covariance case
+            diff = scale_or_sqrt - old_scale_or_sqrt
+            norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
+            return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
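The quantity assembled by `gaussian_wasserstein_commutative` corresponds to the squared 2-Wasserstein distance between Gaussians under the simplifying assumption that the covariances commute (and, per the docstring above, treating the Cholesky factor as if it were the matrix square root):

$$
W_2^2\!\left(\mathcal{N}(\mu,\Sigma),\,\mathcal{N}(\mu_o,\Sigma_o)\right)
= \underbrace{\lVert \mu-\mu_o\rVert_2^2}_{\text{mean part}}
+ \underbrace{\operatorname{tr}\!\big(\Sigma+\Sigma_o-2\,\Sigma^{1/2}\Sigma_o^{1/2}\big)}_{\text{covariance part}}
$$

In the diagonal branch the trace reduces to a sum over dimensions, which is what the `torch.sum(...)` expressions compute.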
@@ -33,24 +33,28 @@ def is_discrete_space(action_space):
         raise ValueError(f"Unsupported action space type: {type(action_space)}")
 
 def get_space_shape(action_space):
-    if gym_available:
-        discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
-        continuous_types = (GymBox,)
-    else:
-        discrete_types = ()
-        continuous_types = ()
-
-    discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
-                       DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
-    continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
+    discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
+                      DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
+    continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)
+
+    if gym_available:
+        discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
+        continuous_types += (GymBox,)
 
     if isinstance(action_space, discrete_types):
-        if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
+        if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
             return (action_space.n,)
-        elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
+        elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
             return (sum(action_space.nvec),)
-        elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
+        elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
             return (action_space.n,)
+        elif gym_available:
+            if isinstance(action_space, GymDiscrete):
+                return (action_space.n,)
+            elif isinstance(action_space, GymMultiDiscrete):
+                return (sum(action_space.nvec),)
+            elif isinstance(action_space, GymMultiBinary):
+                return (action_space.n,)
     elif isinstance(action_space, continuous_types):
         return action_space.shape
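A quick sketch of what the reordered `get_space_shape` returns for common Gymnasium spaces, read directly off the branches above (returned entries may be NumPy integers):

```python
# Sketch: behaviour of get_space_shape for a few gymnasium spaces.
import gymnasium as gym

get_space_shape(gym.spaces.Discrete(4))                          # (4,)
get_space_shape(gym.spaces.MultiDiscrete([2, 3]))                # (5,)  -> sum of nvec
get_space_shape(gym.spaces.Box(low=-1.0, high=1.0, shape=(6,)))  # (6,)
```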
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_ppo_instantiation():
-    ppo = PPO("CartPole-v1")
+    ppo = PPO(simple_env)
+    assert isinstance(ppo, PPO)
+
+def test_ppo_instantiation_from_str():
+    ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])

@@ -19,7 +22,7 @@ def test_ppo_instantiation():
 @pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
 def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
     ppo = PPO(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,

@@ -34,26 +37,42 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
     assert ppo.gamma == gamma
     assert ppo.clip_range == clip_range
 
-def test_ppo_predict(simple_env):
-    ppo = PPO("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_ppo_predict():
+    ppo = PPO(simple_env)
+    env = ppo.make_env()
+    obs, _ = env.reset()
     action, _ = ppo.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_ppo_learn():
-    ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    ppo = PPO(simple_env, n_steps=64, batch_size=32)
+    env = ppo.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = ppo.predict(obs)
-        next_obs, reward, done, truncated, _ = env.step(action)
-        ppo.store_transition(obs, action, reward, done, next_obs)
-        obs = next_obs
+        obs, reward, done, truncated, _ = env.step(action)
         if done or truncated:
             obs, _ = env.reset()
 
-    loss = ppo.learn()
-    assert isinstance(loss, dict)
-    assert "policy_loss" in loss
-    assert "value_loss" in loss
+def test_ppo_training():
+    ppo = PPO(simple_env, total_timesteps=10000)
+    env = ppo.make_env()
+
+    initial_performance = evaluate_policy(ppo, env)
+    ppo.train()
+    final_performance = evaluate_policy(ppo, env)
+
+    assert final_performance > initial_performance, "PPO should improve performance after training"
+
+def evaluate_policy(policy, env, n_eval_episodes=10):
+    total_reward = 0
+    for _ in range(n_eval_episodes):
+        obs, _ = env.reset()
+        done = False
+        while not done:
+            action, _ = policy.predict(obs)
+            obs, reward, terminated, truncated, _ = env.step(action)
+            total_reward += reward
+            done = terminated or truncated
+    return total_reward / n_eval_episodes
@@ -3,12 +3,15 @@ import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
 
-@pytest.fixture
 def simple_env():
-    return gym.make('CartPole-v1')
+    return gym.make('LunarLander-v2', continuous=True)
 
 def test_trpl_instantiation():
-    trpl = TRPL("CartPole-v1")
+    trpl = TRPL(simple_env)
+    assert isinstance(trpl, TRPL)
+
+def test_trpl_instantiation_from_str():
+    trpl = TRPL('MountainCarContinuous-v0')
     assert isinstance(trpl, TRPL)
 
 @pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])

@@ -19,7 +22,7 @@ def test_trpl_instantiation():
 @pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
 def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
     trpl = TRPL(
-        "CartPole-v1",
+        simple_env,
         learning_rate=learning_rate,
         n_steps=n_steps,
         batch_size=batch_size,

@@ -34,16 +37,17 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
     assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
     assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
 
-def test_trpl_predict(simple_env):
-    trpl = TRPL("CartPole-v1")
-    obs, _ = simple_env.reset()
+def test_trpl_predict():
+    trpl = TRPL(simple_env)
+    env = trpl.make_env()
+    obs, _ = env.reset()
     action, _ = trpl.predict(obs)
     assert isinstance(action, np.ndarray)
-    assert action.shape == simple_env.action_space.shape
+    assert action.shape == env.action_space.shape
 
 def test_trpl_learn():
-    trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
-    env = gym.make("CartPole-v1")
+    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
+    env = trpl.make_env()
     obs, _ = env.reset()
     for _ in range(64):
         action, _ = trpl.predict(obs)

@@ -58,12 +62,13 @@ def test_trpl_learn():
     assert "policy_loss" in loss
     assert "value_loss" in loss
 
-def test_trpl_training(simple_env):
-    trpl = TRPL("CartPole-v1", total_timesteps=10000)
+def test_trpl_training():
+    trpl = TRPL(simple_env, total_timesteps=10000)
+    env = trpl.make_env()
 
-    initial_performance = evaluate_policy(trpl, simple_env)
+    initial_performance = evaluate_policy(trpl, env)
     trpl.train()
-    final_performance = evaluate_policy(trpl, simple_env)
+    final_performance = evaluate_policy(trpl, env)
 
     assert final_performance > initial_performance, "TRPL should improve performance after training"
 