diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index 6da8269..0162710 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -13,7 +13,7 @@ class BaseProjectionLayer(object): mean_bound: float = 0.03, cov_bound: float = 1e-3, trust_region_coeff: float = 1.0, - scale_prec: bool = True, + scale_prec: bool = False, do_entropy_proj: bool = False, entropy_eq: bool = False, entropy_first: bool = False, @@ -26,6 +26,10 @@ class BaseProjectionLayer(object): self.scale_prec = scale_prec self.mean_eq = False + assert not entropy_eq, 'Sorry pal; thats actually not implemented yet.' + assert not entropy_first, 'Sorry pal; thats actually not implemented yet.' + assert not do_entropy_proj, 'Sorry pal; thats actually not implemented yet.' + self.entropy_first = entropy_first self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection @@ -63,6 +67,7 @@ class BaseProjectionLayer(object): #################################################################################################################### # entropy projection in the end + if not self.do_entropy_proj or self.entropy_first: return new_p @@ -81,7 +86,7 @@ class BaseProjectionLayer(object): Returns: projected_dist, old_dist (from rollouts) """ - old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps) + old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp) return self(dist, old_distribution, **kwargs), old_distribution def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):