Clean up and make cpp_projections optional
This commit is contained in:
parent
7538599f74
commit
a5309e0fb8
1
metastable_projections/__init__.py
Normal file
1
metastable_projections/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .projections import *
|
@ -1,5 +1,13 @@
|
|||||||
#TODO: License or such
|
#TODO: License or such
|
||||||
from .base_projection_layer import BaseProjectionLayer
|
from .base_projection_layer import BaseProjectionLayer
|
||||||
|
from .identity_projection_layer import IdentityProjectionLayer
|
||||||
from .frob_projection_layer import FrobeniusProjectionLayer
|
from .frob_projection_layer import FrobeniusProjectionLayer
|
||||||
from .kl_projection_layer import KLProjectionLayer
|
|
||||||
from .w2_projection_layer import WassersteinProjectionLayer
|
from .w2_projection_layer import WassersteinProjectionLayer
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cpp_projection
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print('[MSB] ITPAL is not installed; KL projections not avaible.')
|
||||||
|
from .base_projection_layer import ExceptionProjectionLayer as KLProjectionLayer
|
||||||
|
else:
|
||||||
|
from .kl_projection_layer import KLProjectionLayer
|
@ -9,7 +9,6 @@ from ..misc.distTools import *
|
|||||||
|
|
||||||
|
|
||||||
class BaseProjectionLayer(object):
|
class BaseProjectionLayer(object):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
mean_bound: float = 0.03,
|
mean_bound: float = 0.03,
|
||||||
cov_bound: float = 1e-3,
|
cov_bound: float = 1e-3,
|
||||||
@ -30,10 +29,8 @@ class BaseProjectionLayer(object):
|
|||||||
self.entropy_first = entropy_first
|
self.entropy_first = entropy_first
|
||||||
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
|
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
|
||||||
|
|
||||||
def __call__(self, p, q, step, *args, **kwargs):
|
def __call__(self, p, q, **kwargs):
|
||||||
# TODO: self.entropy_schedule(self.initial_entropy, self.target_entropy, self.temperature, step) * p[0].new_ones(p[0].shape[0])
|
return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=None, **kwargs)
|
||||||
entropy_bound = 'lol'
|
|
||||||
return self._projection(p, q, eps=self.mean_bound, eps_cov=self.cov_bound, beta=entropy_bound, **kwargs)
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs):
|
def _projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, beta: th.Tensor, **kwargs):
|
||||||
@ -71,6 +68,22 @@ class BaseProjectionLayer(object):
|
|||||||
|
|
||||||
return self.entropy_proj(new_p, beta)
|
return self.entropy_proj(new_p, beta)
|
||||||
|
|
||||||
|
def project_from_rollouts(self, dist, rollout_data, **kwargs):
|
||||||
|
"""
|
||||||
|
Hook for implementing the specific trust region projection
|
||||||
|
Args:
|
||||||
|
p: current distribution
|
||||||
|
q: old distribution
|
||||||
|
eps: mean trust region bound
|
||||||
|
eps_cov: covariance trust region bound
|
||||||
|
**kwargs:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
projected_dist, old_dist (from rollouts)
|
||||||
|
"""
|
||||||
|
old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps)
|
||||||
|
return self(dist, old_distribution, **kwargs), old_distribution
|
||||||
|
|
||||||
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
||||||
"""
|
"""
|
||||||
Hook for implementing the specific trust region projection
|
Hook for implementing the specific trust region projection
|
||||||
@ -108,6 +121,23 @@ class BaseProjectionLayer(object):
|
|||||||
"""
|
"""
|
||||||
return kl_divergence(p, q)
|
return kl_divergence(p, q)
|
||||||
|
|
||||||
|
def new_dist_like(orig_p, mean, cov_cholesky):
|
||||||
|
assert isinstance(orig_p, Distribution)
|
||||||
|
p = orig_p.distribution
|
||||||
|
if isinstance(p, th.distributions.Normal):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.Normal(mean, cov_cholesky)
|
||||||
|
elif isinstance(p, th.distributions.Independent):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.Independent(
|
||||||
|
th.distributions.Normal(mean, cov_cholesky), 1)
|
||||||
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.MultivariateNormal(
|
||||||
|
mean, scale_tril=cov_cholesky)
|
||||||
|
else:
|
||||||
|
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
||||||
|
return p_out
|
||||||
|
|
||||||
def entropy_inequality_projection(p: th.distributions.Normal,
|
def entropy_inequality_projection(p: th.distributions.Normal,
|
||||||
beta: Union[float, th.Tensor]):
|
beta: Union[float, th.Tensor]):
|
||||||
@ -219,3 +249,10 @@ def mean_equality_projection(mean: th.Tensor, old_mean: th.Tensor, maha: th.Tens
|
|||||||
proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16)
|
proj_mean = (mean + omega * old_mean) / (1 + omega + 1e-16)
|
||||||
|
|
||||||
return proj_mean
|
return proj_mean
|
||||||
|
|
||||||
|
|
||||||
|
class ExceptionProjectionLayer(BaseProjectionLayer):
|
||||||
|
def __init__(self,
|
||||||
|
*args, **kwargs
|
||||||
|
):
|
||||||
|
raise Exception('To be able to use KL projections, ITPAL must be installed: https://github.com/ALRhub/ITPAL (Private Repo).')
|
@ -0,0 +1,5 @@
|
|||||||
|
from .base_projection_layer import BaseProjectionLayer
|
||||||
|
|
||||||
|
class IdentityProjectionLayer(BaseProjectionLayer):
|
||||||
|
def project_from_rollouts(self, dist, rollout_data, **kwargs):
|
||||||
|
return dist, dist
|
@ -138,7 +138,10 @@ class KLProjectionGradFunctionDiagCovOnly(th.autograd.Function):
|
|||||||
batch_shape, dim)
|
batch_shape, dim)
|
||||||
ctx.proj = p_op
|
ctx.proj = p_op
|
||||||
|
|
||||||
|
try:
|
||||||
proj_std = p_op.forward(eps, old_std_np, std_np)
|
proj_std = p_op.forward(eps, old_std_np, std_np)
|
||||||
|
except:
|
||||||
|
proj_std = std_np
|
||||||
|
|
||||||
return cov.new(proj_std)
|
return cov.new(proj_std)
|
||||||
|
|
||||||
|
@ -9,6 +9,8 @@ from .base_projection_layer import BaseProjectionLayer, mean_projection
|
|||||||
from ..misc.norm import mahalanobis, _batch_trace
|
from ..misc.norm import mahalanobis, _batch_trace
|
||||||
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
from ..misc.distTools import get_diag_cov_vec, get_mean_and_chol, get_mean_and_sqrt, get_cov, new_dist_like_from_sqrt, has_diag_cov
|
||||||
|
|
||||||
|
from stable_baselines3.common.distributions import Distribution
|
||||||
|
|
||||||
|
|
||||||
class WassersteinProjectionLayer(BaseProjectionLayer):
|
class WassersteinProjectionLayer(BaseProjectionLayer):
|
||||||
"""
|
"""
|
||||||
@ -92,6 +94,24 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
|
|
||||||
return kl_loss * self.trust_region_coeff
|
return kl_loss * self.trust_region_coeff
|
||||||
|
|
||||||
|
def new_dist_like(orig_p, mean, cov_sqrt):
|
||||||
|
assert isinstance(orig_p, Distribution)
|
||||||
|
p = orig_p.distribution
|
||||||
|
if isinstance(p, th.distributions.Normal):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.Normal(mean, cov_sqrt)
|
||||||
|
elif isinstance(p, th.distributions.Independent):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.Independent(
|
||||||
|
th.distributions.Normal(mean, cov_sqrt), 1)
|
||||||
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.MultivariateNormal(
|
||||||
|
mean, scale_tril=cov_sqrt)
|
||||||
|
else:
|
||||||
|
raise Exception('Dist-Type not implemented (of sb3 dist)')
|
||||||
|
return p_out
|
||||||
|
|
||||||
|
|
||||||
def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]:
|
def gaussian_wasserstein_commutative(p, q, scale_prec=False) -> Tuple[th.Tensor, th.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user