Compare commits

No commits in common. "df1ba6fe53c1fd3c2b490a28df0e586a49a48df8" and "906240e1458d0f3422c9a82f52af15b4cc778b5e" have entirely different histories.

df1ba6fe53 ... 906240e145

23
README.md
@@ -52,34 +52,17 @@ To run the test suite:
pytest test/test_ppo.py
```

## Status
## TODO

### Implemented Features
- Proximal Policy Optimization (PPO) algorithm
- Trust Region Policy Layers (TRPL) algorithm (WIP)
- Support for continuous and discrete action spaces
- Multiple projection methods (rewritten for MIT License compatibility):
  - KL Divergence projection
  - Frobenius norm projection
  - Wasserstein distance projection
  - Identity projection (equivalent to PPO)
- Configurable neural network architectures for actor and critic
- Logging support (Terminal and WandB, extendable)

### TODO
- [ ] All PPO Tests green
- [ ] Better / more logging
- [ ] Test / Benchmark PPO
- [ ] Refactor Modules for TRPL
- [ ] Get TRPL working
- [ ] All TRPL Tests green
- [ ] Test / Benchmark TRPL
- [ ] Make contextual covariance optional
- [ ] Allow full-cov via chol
- [ ] Test / Benchmark TRPL
- [ ] Write docs / extend README
- [ ] Test func of non-gym envs
- [ ] Implement SAC
- [ ] Implement VLEARN
- [ ] (Implement SAC?)

## Contributing
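For orientation, here is a minimal, hedged usage sketch of the interface this diff exercises; the constructor arguments and method names are taken from the test files further down in the comparison, and anything beyond that is an assumption, not an official example from the repository:

```python
# Minimal sketch based on the fancy_rl tests shown in this diff; not authoritative.
import gymnasium as gym
import numpy as np
from fancy_rl import PPO

ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)  # env id string, as in test_ppo.py
ppo.train()                                          # run the configured training loop

env = gym.make("CartPole-v1")
obs, _ = env.reset()
action, _ = ppo.predict(obs)   # predict() returns (action as np.ndarray, state=None)
assert isinstance(action, np.ndarray)
```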
@@ -3,7 +3,6 @@ import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper
from torchrl.record import VideoRecorder
from abc import ABC
from tensordict import TensorDict

from fancy_rl.loggers import TerminalLogger

@@ -54,11 +53,12 @@ class Algo(ABC):
env = GymWrapper(env).to(self.device)
elif callable(env_spec):
env = env_spec()
if not (isinstance(env, gym.Env) or isinstance(env, gym.core.Wrapper)):
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a callable that returned a {}".format(type(env)))
if isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
elif isinstance(env, gym.Env):
env = GymWrapper(env).to(self.device)
else:
raise ValueError("env_spec must be a string or a callable that returns an environment. Was a {}".format(type(env_spec)))
raise ValueError("env_spec must be a string or a callable that returns an environment.")
return env

def train_step(self, batch):

@@ -70,20 +70,6 @@ class Algo(ABC):
def evaluate(self, epoch):
raise NotImplementedError("evaluate method must be implemented in subclass.")

def predict(
self,
observation,
state=None,
deterministic=False
):
with torch.no_grad():
obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
td = TensorDict({"observation": obs_tensor}, batch_size=[1])

action_td = self.prob_actor(td)
action = action_td["action"]

# We're not using recurrent policies, so we'll always return None for the state
next_state = None

return action.squeeze(0).cpu().numpy(), next_state
def dump_video(module):
if isinstance(module, VideoRecorder):
module.dump()
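The make_env logic above accepts an env_spec that is either a Gym ID string or a callable returning an environment. A hedged sketch of the two call forms, using only names that appear in this diff (everything else is assumed):

```python
# Hedged sketch of the env_spec forms handled by Algo.make_env above; not code from this repo.
import gymnasium as gym
from fancy_rl import PPO

ppo_from_id = PPO("CartPole-v1")                     # string: resolved to an env internally
ppo_from_fn = PPO(lambda: gym.make("CartPole-v1"))   # callable: must return a gym.Env or Wrapper
```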
@@ -55,7 +55,7 @@ class OnPolicy(Algo):
# Create collector
self.collector = SyncDataCollector(
create_env_fn=lambda: self.make_env(eval=False),
policy=self.prob_actor,
policy=self.actor,
frames_per_batch=self.n_steps,
total_frames=self.total_timesteps,
device=self.device,
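For context: torchrl's SyncDataCollector, configured above, is an iterable that yields TensorDict batches of frames_per_batch transitions until total_frames is reached. A hedged sketch of how such a collector is typically consumed (standard torchrl usage, not code from this repository; `collector` and `train_step` stand in for self.collector and Algo.train_step):

```python
# Standard torchrl collector loop (sketch only).
for batch in collector:                  # each `batch` is a TensorDict of n_steps frames
    loss_dict = train_step(batch)        # per-batch update, cf. Algo.train_step in this diff
    collector.update_policy_weights_()   # keep the collection policy in sync with the learner
```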
@@ -4,7 +4,7 @@ from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value.advantages import GAE
from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.utils import is_discrete_space
from fancy_rl.projections import get_projection # Updated import

class PPO(OnPolicy):
def __init__(

@@ -31,10 +31,7 @@ class PPO(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
self.clip_range = clip_range

device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

@@ -44,28 +41,14 @@ class PPO(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space

self.discrete = is_discrete_space(act_space)

self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

if self.discrete:
distribution_class = torch.distributions.Categorical
distribution_kwargs = {"logits": "action_logits"}
else:
if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
in_keys = ["loc", "scale_tril"]
else:
distribution_class = torch.distributions.Normal
in_keys = ["loc", "scale"]

self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=in_keys,
out_keys=["action"]
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)

optimizers = {

@@ -103,7 +86,7 @@ class PPO(OnPolicy):
self.loss_module = ClipPPOLoss(
actor_network=self.actor,
critic_network=self.critic,
clip_epsilon=self.clip_range,
clip_epsilon=clip_range,
loss_critic_type='l2',
entropy_coef=self.entropy_coef,
critic_coef=self.critic_coef,
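For reference, the clipped surrogate objective that torchrl's ClipPPOLoss (configured above) implements is the standard PPO loss; stated here for readers of the diff, it is not taken from this repository's documentation:

```latex
% Standard PPO clipped surrogate objective (Schulman et al., 2017).
% r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t),
% \hat{A}_t = GAE advantage estimate; \epsilon corresponds to clip_range / clip_epsilon above.
L^{\mathrm{CLIP}}(\theta) =
  \mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\;
  \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right]
```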
@@ -10,36 +10,7 @@ from fancy_rl.algos.on_policy import OnPolicy
from fancy_rl.policy import Actor, Critic
from fancy_rl.projections import get_projection, BaseProjection
from fancy_rl.objectives import TRPLLoss
from fancy_rl.utils import is_discrete_space
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict

class ProjectedActor(TensorDictModule):
def __init__(self, raw_actor, old_actor, projection):
combined_module = self.CombinedModule(raw_actor, old_actor, projection)
super().__init__(
module=combined_module,
in_keys=raw_actor.in_keys,
out_keys=raw_actor.out_keys
)
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection

class CombinedModule(nn.Module):
def __init__(self, raw_actor, old_actor, projection):
super().__init__()
self.raw_actor = raw_actor
self.old_actor = old_actor
self.projection = projection

def forward(self, tensordict):
raw_params = self.raw_actor(tensordict)
old_params = self.old_actor(tensordict)
combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
projected_params = self.projection(combined_params)
return projected_params

class TRPL(OnPolicy):
def __init__(

@@ -69,7 +40,6 @@ class TRPL(OnPolicy):
device=None,
env_spec_eval=None,
eval_episodes=10,
full_covariance=False,
):
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.device = device

@@ -80,11 +50,8 @@ class TRPL(OnPolicy):
obs_space = env.observation_space
act_space = env.action_space

assert not is_discrete_space(act_space), "TRPL does not support discrete action spaces"

self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
actor_net = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device)

# Handle projection_class
if isinstance(projection_class, str):

@@ -93,27 +60,20 @@ class TRPL(OnPolicy):
raise ValueError("projection_class must be a string or a subclass of BaseProjection")

self.projection = projection_class(
in_keys=["loc", "scale_tril", "old_loc", "old_scale_tril"] if full_covariance else ["loc", "scale", "old_loc", "old_scale"],
out_keys=["loc", "scale_tril"] if full_covariance else ["loc", "scale"],
mean_bound=trust_region_bound_mean,
cov_bound=trust_region_bound_cov
in_keys=["loc", "scale"],
out_keys=["loc", "scale"],
trust_region_bound_mean=trust_region_bound_mean,
trust_region_bound_cov=trust_region_bound_cov
)

self.actor = ProjectedActor(self.raw_actor, self.old_actor, self.projection)

if full_covariance:
distribution_class = torch.distributions.MultivariateNormal
distribution_kwargs = {"loc": "loc", "scale_tril": "scale_tril"}
else:
distribution_class = torch.distributions.Normal
distribution_kwargs = {"loc": "loc", "scale": "scale"}

self.prob_actor = ProbabilisticActor(
module=self.actor,
distribution_class=distribution_class,
return_log_prob=True,
in_keys=distribution_kwargs,
self.actor = ProbabilisticActor(
module=actor_net,
in_keys=["observation"],
out_keys=["loc", "scale"],
distribution_class=torch.distributions.Normal,
return_log_prob=True
)
self.old_actor = deepcopy(self.actor)

self.trust_region_coef = trust_region_coef
self.loss_module = TRPLLoss(

@@ -128,7 +88,7 @@ class TRPL(OnPolicy):
)

optimizers = {
"actor": torch.optim.Adam(self.raw_actor.parameters(), lr=learning_rate),
"actor": torch.optim.Adam(self.actor.parameters(), lr=learning_rate),
"critic": torch.optim.Adam(self.critic.parameters(), lr=learning_rate)
}

@@ -159,7 +119,23 @@ class TRPL(OnPolicy):
)

def update_old_policy(self):
self.old_actor.load_state_dict(self.raw_actor.state_dict())
self.old_actor.load_state_dict(self.actor.state_dict())

def project_policy(self, obs):
with torch.no_grad():
old_dist = self.old_actor(obs)
new_dist = self.actor(obs)
projected_params = self.projection.project(new_dist, old_dist)
return projected_params

def pre_update(self, tensordict):
obs = tensordict["observation"]
projected_dist = self.project_policy(obs)

# Update tensordict with projected distribution parameters
tensordict["projected_loc"] = projected_dist[0]
tensordict["projected_scale"] = projected_dist[1]
return tensordict

def post_update(self):
self.update_old_policy()
@@ -83,8 +83,8 @@ class TRPLLoss(PPOLoss):

def _trust_region_loss(self, tensordict):
old_distribution = self.old_actor_network(tensordict)
new_distribution = self.actor_network(tensordict)
return self.projection.get_trust_region_loss(new_distribution, old_distribution)
raw_distribution = self.actor_network(tensordict)
return self.projection(self.actor_network, raw_distribution, old_distribution)

def forward(self, tensordict: TensorDictBase) -> TensorDict:
tensordict = tensordict.clone(False)
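Reading the TRPL pieces together (trust_region_coef in the constructor, TRPLLoss above, and get_trust_region_loss in the projection classes below), the overall objective appears to be a PPO-style surrogate on the projected policy plus a penalty that regresses the unprojected policy onto its projection. This is a hedged reading of the code, following the TRPL idea rather than any documentation in this repository:

```latex
% Hedged reading of the TRPL objective assembled in this diff.
% \pi_\theta = raw (unprojected) policy, \tilde{\pi}_\theta = projected policy,
% d(\cdot,\cdot) = divergence matching the chosen projection (KL, Frobenius, or W2).
L_{\mathrm{TRPL}}(\theta) = L_{\mathrm{surrogate}}(\tilde{\pi}_\theta)
  \;+\; \alpha \, \mathbb{E}_s\!\left[ d\!\big(\pi_\theta(\cdot \mid s),\, \tilde{\pi}_\theta(\cdot \mid s)\big) \right],
\qquad \alpha = \texttt{trust\_region\_coef}
```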
@@ -3,63 +3,31 @@ from tensordict.nn import TensorDictModule
from torchrl.modules import MLP
from tensordict.nn.distributions import NormalParamExtractor
from fancy_rl.utils import is_discrete_space, get_space_shape
from tensordict import TensorDict

class Actor(TensorDictModule):
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
self.discrete = is_discrete_space(act_space)
def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
act_space_shape = get_space_shape(act_space)

if self.discrete and full_covariance:
raise ValueError("Full covariance is not applicable for discrete action spaces.")

self.full_covariance = full_covariance

if self.discrete:
if is_discrete_space(act_space):
out_features = act_space_shape[-1]
out_keys = ["action_logits"]
else:
if full_covariance:
out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
out_keys = ["loc", "scale_tril"]
else:
out_features = act_space_shape[-1] * 2
out_keys = ["loc", "scale"]

actor_module = MLP(
actor_module = nn.Sequential(
MLP(
in_features=get_space_shape(obs_space)[-1],
out_features=out_features,
num_cells=hidden_sizes,
activation_class=getattr(nn, activation_fn),
device=device
),
NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
).to(device)

if not self.discrete:
if full_covariance:
param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
else:
param_extractor = NormalParamExtractor()
actor_module = nn.Sequential(actor_module, param_extractor)

super().__init__(
module=actor_module,
in_keys=["observation"],
out_keys=out_keys
out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
)

class FullCovarianceNormalParamExtractor(nn.Module):
def __init__(self, action_dim):
super().__init__()
self.action_dim = action_dim

def forward(self, x):
loc = x[:, :self.action_dim]
scale_tril = torch.zeros(x.shape[0], self.action_dim, self.action_dim, device=x.device)
tril_indices = torch.tril_indices(row=self.action_dim, col=self.action_dim, offset=0)
scale_tril[:, tril_indices[0], tril_indices[1]] = x[:, self.action_dim:]
scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
return TensorDict({"loc": loc, "scale_tril": scale_tril}, batch_size=x.shape[0])

class Critic(TensorDictModule):
def __init__(self, obs_space, hidden_sizes, activation_fn, device):
critic_module = MLP(
@@ -1,71 +1,16 @@
from abc import ABC, abstractmethod
import torch
from torch import nn
from typing import Dict, List
from typing import Dict

class BaseProjection(nn.Module, ABC):
def __init__(self, in_keys: List[str], out_keys: List[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
class BaseProjection(ABC, torch.nn.Module):
def __init__(self, in_keys: list[str], out_keys: list[str]):
super().__init__()
self._validate_in_keys(in_keys)
self._validate_out_keys(out_keys)
self.in_keys = in_keys
self.out_keys = out_keys
self.trust_region_coeff = trust_region_coeff
self.mean_bound = mean_bound
self.cov_bound = cov_bound
self.full_cov = "scale_tril" in in_keys
self.contextual_std = contextual_std

def _validate_in_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril", "old_loc", "old_scale", "old_scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid in_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys or "old_loc" not in keys:
raise ValueError("Both 'loc' and 'old_loc' must be included in in_keys")
if ("scale" in keys) != ("old_scale" in keys) or ("scale_tril" in keys) != ("old_scale_tril" in keys):
raise ValueError("in_keys must have matching 'scale'/'old_scale' or 'scale_tril'/'old_scale_tril'")

def _validate_out_keys(self, keys: List[str]):
valid_keys = {"loc", "scale", "scale_tril"}
if not set(keys).issubset(valid_keys):
raise ValueError(f"Invalid out_keys: {keys}. Must be a subset of {valid_keys}")
if "loc" not in keys:
raise ValueError("'loc' must be included in out_keys")
if "scale" not in keys and "scale_tril" not in keys:
raise ValueError("Either 'scale' or 'scale_tril' must be included in out_keys")

@abstractmethod
def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pass

@abstractmethod
def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
pass

def forward(self, tensordict):
policy_params = {}
old_policy_params = {}

for key in self.in_keys:
if key not in tensordict:
raise KeyError(f"Key '{key}' not found in tensordict. Available keys: {tensordict.keys()}")

if key.startswith("old_"):
old_policy_params[key[4:]] = tensordict[key]
else:
policy_params[key] = tensordict[key]

projected_params = self.project(policy_params, old_policy_params)
return projected_params

def _calc_covariance(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
if not self.full_cov:
return torch.diag_embed(params["scale"].pow(2))
else:
return torch.matmul(params["scale_tril"], params["scale_tril"].transpose(-1, -2))

def _calc_scale_or_scale_tril(self, cov: torch.Tensor) -> torch.Tensor:
if not self.full_cov:
return torch.sqrt(cov.diagonal(dim1=-2, dim2=-1))
else:
return torch.linalg.cholesky(cov)
def forward(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return self.project(policy_params, old_policy_params)
@@ -1,34 +1,33 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict

class FrobeniusProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec

def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
mean, chol = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_chol = old_policy_params["loc"], old_policy_params["scale_tril"]

cov = self._calc_covariance(policy_params)
old_cov = self._calc_covariance(old_policy_params)
cov = torch.matmul(chol, chol.transpose(-1, -2))
old_cov = torch.matmul(old_chol, old_chol.transpose(-1, -2))

mean_part, cov_part = self._gaussian_frobenius((mean, cov), (old_mean, old_cov))

proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_cov = self._cov_projection(cov, old_cov, cov_part)

scale_or_scale_tril = self._calc_scale_or_scale_tril(proj_cov)
return {"loc": proj_mean, self.out_keys[1]: scale_or_scale_tril}
proj_chol = torch.linalg.cholesky(proj_cov)
return {"loc": proj_mean, "scale_tril": proj_chol}

def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
mean, chol = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_chol = proj_policy_params["loc"], proj_policy_params["scale_tril"]

cov = self._calc_covariance(policy_params)
proj_cov = self._calc_covariance(proj_policy_params)
cov = torch.matmul(chol, chol.transpose(-1, -2))
proj_cov = torch.matmul(proj_chol, proj_chol.transpose(-1, -2))

mean_diff = torch.sum(torch.square(mean - proj_mean), dim=-1)
cov_diff = torch.sum(torch.square(cov - proj_cov), dim=(-2, -1))
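From mean_diff and cov_diff above, the Frobenius trust-region quantities appear to be squared Euclidean / Frobenius distances between the two Gaussian parameterizations; a hedged summary of what this hunk computes, not a statement from the repository's docs:

```latex
% Hedged reading of the Frobenius projection quantities in the hunk above,
% with \Sigma = L L^\top reconstructed from the Cholesky factor / scale.
d_{\text{mean}} = \lVert \mu - \mu' \rVert_2^2, \qquad
d_{\text{cov}}  = \lVert \Sigma - \Sigma' \rVert_F^2
```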
@@ -3,8 +3,8 @@ from .base_projection import BaseProjection
from typing import Dict

class IdentityProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)

def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return policy_params
@@ -2,7 +2,6 @@ import torch
import cpp_projection
import numpy as np
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple, Any

MAX_EVAL = 1000

@@ -11,65 +10,57 @@ def get_numpy(tensor):
return tensor.detach().cpu().numpy()

class KLProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, is_diag: bool = True, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.is_diag = is_diag
self.contextual_std = contextual_std

def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
old_mean, old_scale_or_tril = old_policy_params["loc"], old_policy_params[self.in_keys[1]]
mean, std = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_std = old_policy_params["loc"], old_policy_params["scale_tril"]

mean_part, cov_part = self._gaussian_kl((mean, scale_or_tril), (old_mean, old_scale_or_tril))
mean_part, cov_part = self._gaussian_kl((mean, std), (old_mean, old_std))

if not self.contextual_std:
scale_or_tril = scale_or_tril[:1]
old_scale_or_tril = old_scale_or_tril[:1]
std = std[:1]
old_std = old_std[:1]
cov_part = cov_part[:1]

proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_tril = self._cov_projection(scale_or_tril, old_scale_or_tril, cov_part)
proj_std = self._cov_projection(std, old_std, cov_part)

if not self.contextual_std:
proj_scale_or_tril = proj_scale_or_tril.expand(mean.shape[0], *proj_scale_or_tril.shape[1:])
proj_std = proj_std.expand(mean.shape[0], -1, -1)

return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_tril}
return {"loc": proj_mean, "scale_tril": proj_std}

def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean, scale_or_tril = policy_params["loc"], policy_params[self.in_keys[1]]
proj_mean, proj_scale_or_tril = proj_policy_params["loc"], proj_policy_params[self.out_keys[1]]
kl = sum(self._gaussian_kl((mean, scale_or_tril), (proj_mean, proj_scale_or_tril)))
mean, std = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_std = proj_policy_params["loc"], proj_policy_params["scale_tril"]
kl = sum(self._gaussian_kl((mean, std), (proj_mean, proj_std)))
return kl.mean() * self.trust_region_coeff

def _gaussian_kl(self, p: Tuple[torch.Tensor, torch.Tensor], q: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
mean, scale_or_tril = p
mean_other, scale_or_tril_other = q
mean, std = p
mean_other, std_other = q
k = mean.shape[-1]

maha_part = 0.5 * self._maha(mean, mean_other, scale_or_tril_other)
maha_part = 0.5 * self._maha(mean, mean_other, std_other)

det_term = self._log_determinant(scale_or_tril)
det_term_other = self._log_determinant(scale_or_tril_other)

if self.full_cov:
trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(scale_or_tril_other, scale_or_tril, upper=False))
else:
trace_part = torch.sum((scale_or_tril / scale_or_tril_other) ** 2, dim=-1)
det_term = self._log_determinant(std)
det_term_other = self._log_determinant(std_other)

trace_part = self._torch_batched_trace_square(torch.linalg.solve_triangular(std_other, std, upper=False))
cov_part = 0.5 * (trace_part - k + det_term_other - det_term)

return maha_part, cov_part

def _maha(self, x: torch.Tensor, y: torch.Tensor, scale_or_tril: torch.Tensor) -> torch.Tensor:
def _maha(self, x: torch.Tensor, y: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
diff = x - y
if self.full_cov:
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), scale_or_tril, upper=False)[0].squeeze(-1)), dim=-1)
else:
return torch.sum(torch.square(diff / scale_or_tril), dim=-1)
return torch.sum(torch.square(torch.triangular_solve(diff.unsqueeze(-1), std, upper=False)[0].squeeze(-1)), dim=-1)

def _log_determinant(self, scale_or_tril: torch.Tensor) -> torch.Tensor:
if self.full_cov:
return 2 * torch.log(scale_or_tril.diagonal(dim1=-2, dim2=-1)).sum(-1)
else:
return 2 * torch.log(scale_or_tril).sum(-1)
def _log_determinant(self, std: torch.Tensor) -> torch.Tensor:
return 2 * torch.log(std.diagonal(dim1=-2, dim2=-1)).sum(-1)

def _torch_batched_trace_square(self, x: torch.Tensor) -> torch.Tensor:
return torch.sum(x.pow(2), dim=(-2, -1))

@@ -77,45 +68,49 @@ class KLProjection(BaseProjection):
def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
return old_mean + (mean - old_mean) * torch.sqrt(self.mean_bound / (mean_part + 1e-8)).unsqueeze(-1)

def _cov_projection(self, scale_or_tril: torch.Tensor, old_scale_or_tril: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if self.full_cov:
cov = torch.matmul(scale_or_tril, scale_or_tril.transpose(-1, -2))
old_cov = torch.matmul(old_scale_or_tril, old_scale_or_tril.transpose(-1, -2))
else:
cov = scale_or_tril.pow(2)
old_cov = old_scale_or_tril.pow(2)
def _cov_projection(self, std: torch.Tensor, old_std: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
cov = torch.matmul(std, std.transpose(-1, -2))
old_cov = torch.matmul(old_std, old_std.transpose(-1, -2))

if self.is_diag:
mask = cov_part > self.cov_bound
proj_scale_or_tril = torch.zeros_like(scale_or_tril)
proj_scale_or_tril[~mask] = scale_or_tril[~mask]

proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
try:
if mask.any():
if self.full_cov:
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, scale_or_tril.detach(), old_scale_or_tril, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
mask &= ~is_invalid
proj_scale_or_tril[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_scale_or_tril[failed_mask] = old_scale_or_tril[failed_mask]
else:
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov, old_cov, self.cov_bound)
proj_cov = KLProjectionGradFunctionDiagCovOnly.apply(cov.diagonal(dim1=-2, dim2=-1),
old_cov.diagonal(dim1=-2, dim2=-1),
self.cov_bound)
is_invalid = (proj_cov.mean(dim=-1).isnan() | proj_cov.mean(dim=-1).isinf() | (proj_cov.min(dim=-1).values < 0)) & mask
if is_invalid.any():
proj_scale_or_tril[is_invalid] = old_scale_or_tril[is_invalid]
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_scale_or_tril[mask] = proj_cov[mask].sqrt()
proj_std[mask] = proj_cov[mask].sqrt().diag_embed()
except Exception as e:
proj_std = old_std
else:
try:
mask = cov_part > self.cov_bound
proj_std = torch.zeros_like(std)
proj_std[~mask] = std[~mask]
if mask.any():
proj_cov = KLProjectionGradFunctionCovOnly.apply(cov, std.detach(), old_std, self.cov_bound)
is_invalid = proj_cov.mean([-2, -1]).isnan() & mask
if is_invalid.any():
proj_std[is_invalid] = old_std[is_invalid]
mask &= ~is_invalid
proj_std[mask], failed_mask = torch.linalg.cholesky_ex(proj_cov[mask])
failed_mask = failed_mask.bool()
if failed_mask.any():
proj_std[failed_mask] = old_std[failed_mask]
except Exception as e:
import logging
logging.error('Projection failed, taking old scale_or_tril for projection.')
print("Projection failed, taking old scale_or_tril for projection.")
proj_scale_or_tril = old_scale_or_tril
logging.error('Projection failed, taking old cholesky for projection.')
print("Projection failed, taking old cholesky for projection.")
proj_std = old_std
raise e

return proj_scale_or_tril
return proj_std

class KLProjectionGradFunctionCovOnly(torch.autograd.Function):
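The _gaussian_kl, _maha, and _log_determinant helpers above compute the two halves of the standard KL divergence between Gaussians, split into a mean term and a covariance term. The formula is textbook and not specific to this repository; it is stated here only to make the split explicit:

```latex
% KL between N(\mu, \Sigma) and N(\mu_o, \Sigma_o), split as in _gaussian_kl above
% (k = action dimension; the first term is the mean part, the second the covariance part).
\mathrm{KL} = \tfrac{1}{2}(\mu-\mu_o)^{\top}\Sigma_o^{-1}(\mu-\mu_o)
 \;+\; \tfrac{1}{2}\big(\operatorname{tr}(\Sigma_o^{-1}\Sigma) - k + \ln\lvert\Sigma_o\rvert - \ln\lvert\Sigma\rvert\big)
```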
@@ -1,86 +1,56 @@
import torch
from .base_projection import BaseProjection
from tensordict.nn import TensorDictModule
from typing import Dict, Tuple

def scale_tril_to_sqrt(scale_tril: torch.Tensor) -> torch.Tensor:
"""
'Converts' scale_tril to scale_sqrt.

For Wasserstein distance, we need the matrix square root, not the Cholesky decomposition.
But since both are lower triangular, we can treat the Cholesky decomposition as if it were the matrix square root.
"""
return scale_tril

def gaussian_wasserstein_commutative(policy, p: Tuple[torch.Tensor, torch.Tensor],
q: Tuple[torch.Tensor, torch.Tensor], scale_prec=False) -> Tuple[torch.Tensor, torch.Tensor]:
mean, scale_or_sqrt = p
mean_other, scale_or_sqrt_other = q
mean, sqrt = p
mean_other, sqrt_other = q

mean_part = torch.sum(torch.square(mean - mean_other), dim=-1)

if scale_or_sqrt.dim() == mean.dim(): # Diagonal case
cov = scale_or_sqrt.pow(2)
cov_other = scale_or_sqrt_other.pow(2)
cov = torch.matmul(sqrt, sqrt.transpose(-1, -2))
cov_other = torch.matmul(sqrt_other, sqrt_other.transpose(-1, -2))

if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = 1 / scale_or_sqrt_other
c = sqrt_inv_other.pow(2) * cov
cov_part = torch.sum(identity + c - 2 * sqrt_inv_other * scale_or_sqrt, dim=-1)
identity = torch.eye(mean.shape[-1], dtype=sqrt.dtype, device=sqrt.device)
sqrt_inv_other = torch.linalg.solve(sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ sqrt)
else:
cov_part = torch.sum(cov_other + cov - 2 * scale_or_sqrt_other * scale_or_sqrt, dim=-1)
else: # Full covariance case
# Note: scale_or_sqrt is treated as the matrix square root, not Cholesky decomposition
cov = torch.matmul(scale_or_sqrt, scale_or_sqrt.transpose(-1, -2))
cov_other = torch.matmul(scale_or_sqrt_other, scale_or_sqrt_other.transpose(-1, -2))
if scale_prec:
identity = torch.eye(mean.shape[-1], dtype=scale_or_sqrt.dtype, device=scale_or_sqrt.device)
sqrt_inv_other = torch.linalg.solve(scale_or_sqrt_other, identity)
c = sqrt_inv_other @ cov @ sqrt_inv_other.transpose(-1, -2)
cov_part = torch.trace(identity + c - 2 * sqrt_inv_other @ scale_or_sqrt)
else:
cov_part = torch.trace(cov_other + cov - 2 * scale_or_sqrt_other @ scale_or_sqrt)
cov_part = torch.trace(cov_other + cov - 2 * sqrt_other @ sqrt)

return mean_part, cov_part

class WassersteinProjection(BaseProjection):
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False, contextual_std: bool = True):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound, contextual_std=contextual_std)
def __init__(self, in_keys: list[str], out_keys: list[str], trust_region_coeff: float = 1.0, mean_bound: float = 0.01, cov_bound: float = 0.01, scale_prec: bool = False):
super().__init__(in_keys=in_keys, out_keys=out_keys, trust_region_coeff=trust_region_coeff, mean_bound=mean_bound, cov_bound=cov_bound)
self.scale_prec = scale_prec

def project(self, policy_params: Dict[str, torch.Tensor], old_policy_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
mean = policy_params["loc"]
old_mean = old_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
old_scale_or_sqrt = scale_tril_to_sqrt(old_policy_params[self.in_keys[1]])
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
old_mean, old_sqrt = old_policy_params["loc"], old_policy_params["scale_tril"]

mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (old_mean, old_scale_or_sqrt), self.scale_prec)
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (old_mean, old_sqrt), self.scale_prec)

proj_mean = self._mean_projection(mean, old_mean, mean_part)
proj_scale_or_sqrt = self._cov_projection(scale_or_sqrt, old_scale_or_sqrt, cov_part)
proj_sqrt = self._cov_projection(sqrt, old_sqrt, cov_part)

return {"loc": proj_mean, self.out_keys[1]: proj_scale_or_sqrt}
return {"loc": proj_mean, "scale_tril": proj_sqrt}

def get_trust_region_loss(self, policy_params: Dict[str, torch.Tensor], proj_policy_params: Dict[str, torch.Tensor]) -> torch.Tensor:
mean = policy_params["loc"]
proj_mean = proj_policy_params["loc"]
scale_or_sqrt = scale_tril_to_sqrt(policy_params[self.in_keys[1]])
proj_scale_or_sqrt = scale_tril_to_sqrt(proj_policy_params[self.out_keys[1]])
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, scale_or_sqrt), (proj_mean, proj_scale_or_sqrt), self.scale_prec)
mean, sqrt = policy_params["loc"], policy_params["scale_tril"]
proj_mean, proj_sqrt = proj_policy_params["loc"], proj_policy_params["scale_tril"]
mean_part, cov_part = gaussian_wasserstein_commutative(None, (mean, sqrt), (proj_mean, proj_sqrt), self.scale_prec)
w2 = mean_part + cov_part
return w2.mean() * self.trust_region_coeff

def _mean_projection(self, mean: torch.Tensor, old_mean: torch.Tensor, mean_part: torch.Tensor) -> torch.Tensor:
diff = mean - old_mean
norm = torch.sqrt(mean_part)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm.unsqueeze(-1), mean)
norm = torch.norm(diff, dim=-1, keepdim=True)
return torch.where(norm > self.mean_bound, old_mean + diff * self.mean_bound / norm, mean)

def _cov_projection(self, scale_or_sqrt: torch.Tensor, old_scale_or_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
if scale_or_sqrt.dim() == old_scale_or_sqrt.dim() == 2: # Diagonal case
diff = scale_or_sqrt - old_scale_or_sqrt
norm = torch.sqrt(cov_part)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm.unsqueeze(-1), scale_or_sqrt)
else: # Full covariance case
diff = scale_or_sqrt - old_scale_or_sqrt
def _cov_projection(self, sqrt: torch.Tensor, old_sqrt: torch.Tensor, cov_part: torch.Tensor) -> torch.Tensor:
diff = sqrt - old_sqrt
norm = torch.norm(diff, dim=(-2, -1), keepdim=True)
return torch.where(norm > self.cov_bound, old_scale_or_sqrt + diff * self.cov_bound / norm, scale_or_sqrt)
return torch.where(norm > self.cov_bound, old_sqrt + diff * self.cov_bound / norm, sqrt)
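gaussian_wasserstein_commutative above uses the simplification of the squared 2-Wasserstein distance between Gaussians that holds when the two covariances commute (for example, when both are diagonal); a hedged note on the formula being approximated, included for readers rather than taken from this repository:

```latex
% Squared 2-Wasserstein distance between N(\mu_1,\Sigma_1) and N(\mu_2,\Sigma_2)
% under the commutativity assumption used above (\Sigma_1\Sigma_2 = \Sigma_2\Sigma_1):
W_2^2 = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{tr}\!\big(\Sigma_1 + \Sigma_2 - 2\,\Sigma_1^{1/2}\Sigma_2^{1/2}\big)
  = \lVert \mu_1 - \mu_2 \rVert_2^2 + \lVert \Sigma_1^{1/2} - \Sigma_2^{1/2} \rVert_F^2
```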
@@ -33,27 +33,23 @@ def is_discrete_space(action_space):
raise ValueError(f"Unsupported action space type: {type(action_space)}")

def get_space_shape(action_space):
discrete_types = (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types = (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)

if gym_available:
discrete_types += (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types += (GymBox,)
discrete_types = (GymDiscrete, GymMultiDiscrete, GymMultiBinary)
continuous_types = (GymBox,)
else:
discrete_types = ()
continuous_types = ()

discrete_types += (GymnasiumDiscrete, GymnasiumMultiDiscrete, GymnasiumMultiBinary,
DiscreteTensorSpec, OneHotDiscreteTensorSpec, MultiDiscreteTensorSpec, BinaryDiscreteTensorSpec)
continuous_types += (GymnasiumBox, BoundedTensorSpec, UnboundedContinuousTensorSpec)

if isinstance(action_space, discrete_types):
if isinstance(action_space, (GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
if isinstance(action_space, (GymDiscrete, GymnasiumDiscrete, DiscreteTensorSpec, OneHotDiscreteTensorSpec)):
return (action_space.n,)
elif isinstance(action_space, (GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
elif isinstance(action_space, (GymMultiDiscrete, GymnasiumMultiDiscrete, MultiDiscreteTensorSpec)):
return (sum(action_space.nvec),)
elif isinstance(action_space, (GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif gym_available:
if isinstance(action_space, GymDiscrete):
return (action_space.n,)
elif isinstance(action_space, GymMultiDiscrete):
return (sum(action_space.nvec),)
elif isinstance(action_space, GymMultiBinary):
elif isinstance(action_space, (GymMultiBinary, GymnasiumMultiBinary, BinaryDiscreteTensorSpec)):
return (action_space.n,)
elif isinstance(action_space, continuous_types):
return action_space.shape
@@ -3,15 +3,12 @@ import numpy as np
from fancy_rl import PPO
import gymnasium as gym

@pytest.fixture
def simple_env():
return gym.make('LunarLander-v2', continuous=True)
return gym.make('CartPole-v1')

def test_ppo_instantiation():
ppo = PPO(simple_env)
assert isinstance(ppo, PPO)

def test_ppo_instantiation_from_str():
ppo = PPO('CartPole-v1')
ppo = PPO("CartPole-v1")
assert isinstance(ppo, PPO)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])

@@ -22,7 +19,7 @@ def test_ppo_instantiation_from_str():
@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
ppo = PPO(
simple_env,
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,

@@ -37,42 +34,26 @@ def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_siz
assert ppo.gamma == gamma
assert ppo.clip_range == clip_range

def test_ppo_predict():
ppo = PPO(simple_env)
env = ppo.make_env()
obs, _ = env.reset()
def test_ppo_predict(simple_env):
ppo = PPO("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = ppo.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
assert action.shape == simple_env.action_space.shape

def test_ppo_learn():
ppo = PPO(simple_env, n_steps=64, batch_size=32)
env = ppo.make_env()
ppo = PPO("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = ppo.predict(obs)
obs, reward, done, truncated, _ = env.step(action)
next_obs, reward, done, truncated, _ = env.step(action)
ppo.store_transition(obs, action, reward, done, next_obs)
obs = next_obs
if done or truncated:
obs, _ = env.reset()

def test_ppo_training():
ppo = PPO(simple_env, total_timesteps=10000)
env = ppo.make_env()

initial_performance = evaluate_policy(ppo, env)
ppo.train()
final_performance = evaluate_policy(ppo, env)

assert final_performance > initial_performance, "PPO should improve performance after training"

def evaluate_policy(policy, env, n_eval_episodes=10):
total_reward = 0
for _ in range(n_eval_episodes):
obs, _ = env.reset()
done = False
while not done:
action, _ = policy.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
total_reward += reward
done = terminated or truncated
return total_reward / n_eval_episodes
loss = ppo.learn()
assert isinstance(loss, dict)
assert "policy_loss" in loss
assert "value_loss" in loss
@@ -3,15 +3,12 @@ import numpy as np
from fancy_rl import TRPL
import gymnasium as gym

@pytest.fixture
def simple_env():
return gym.make('LunarLander-v2', continuous=True)
return gym.make('CartPole-v1')

def test_trpl_instantiation():
trpl = TRPL(simple_env)
assert isinstance(trpl, TRPL)

def test_trpl_instantiation_from_str():
trpl = TRPL('MountainCarContinuous-v0')
trpl = TRPL("CartPole-v1")
assert isinstance(trpl, TRPL)

@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])

@@ -22,7 +19,7 @@ def test_trpl_instantiation_from_str():
@pytest.mark.parametrize("trust_region_bound_cov", [0.0005, 0.001])
def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_size, gamma, trust_region_bound_mean, trust_region_bound_cov):
trpl = TRPL(
simple_env,
"CartPole-v1",
learning_rate=learning_rate,
n_steps=n_steps,
batch_size=batch_size,

@@ -37,17 +34,16 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov

def test_trpl_predict():
trpl = TRPL(simple_env)
env = trpl.make_env()
obs, _ = env.reset()
def test_trpl_predict(simple_env):
trpl = TRPL("CartPole-v1")
obs, _ = simple_env.reset()
action, _ = trpl.predict(obs)
assert isinstance(action, np.ndarray)
assert action.shape == env.action_space.shape
assert action.shape == simple_env.action_space.shape

def test_trpl_learn():
trpl = TRPL(simple_env, n_steps=64, batch_size=32)
env = trpl.make_env()
trpl = TRPL("CartPole-v1", n_steps=64, batch_size=32)
env = gym.make("CartPole-v1")
obs, _ = env.reset()
for _ in range(64):
action, _ = trpl.predict(obs)

@@ -62,13 +58,12 @@ def test_trpl_learn():
assert "policy_loss" in loss
assert "value_loss" in loss

def test_trpl_training():
trpl = TRPL(simple_env, total_timesteps=10000)
env = trpl.make_env()
def test_trpl_training(simple_env):
trpl = TRPL("CartPole-v1", total_timesteps=10000)

initial_performance = evaluate_policy(trpl, env)
initial_performance = evaluate_policy(trpl, simple_env)
trpl.train()
final_performance = evaluate_policy(trpl, env)
final_performance = evaluate_policy(trpl, simple_env)

assert final_performance > initial_performance, "TRPL should improve performance after training"