CHanges to work with new version of metastable baselines 2

This commit is contained in:
Dominik Moritz Roth 2024-01-26 12:01:11 +01:00
parent a5309e0fb8
commit 8a438e275f

View File

@ -13,7 +13,7 @@ class BaseProjectionLayer(object):
mean_bound: float = 0.03,
cov_bound: float = 1e-3,
trust_region_coeff: float = 1.0,
scale_prec: bool = True,
scale_prec: bool = False,
do_entropy_proj: bool = False,
entropy_eq: bool = False,
entropy_first: bool = False,
@ -26,6 +26,10 @@ class BaseProjectionLayer(object):
self.scale_prec = scale_prec
self.mean_eq = False
assert not entropy_eq, 'Sorry pal; thats actually not implemented yet.'
assert not entropy_first, 'Sorry pal; thats actually not implemented yet.'
assert not do_entropy_proj, 'Sorry pal; thats actually not implemented yet.'
self.entropy_first = entropy_first
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
@ -63,6 +67,7 @@ class BaseProjectionLayer(object):
####################################################################################################################
# entropy projection in the end
if not self.do_entropy_proj or self.entropy_first:
return new_p
@ -81,7 +86,7 @@ class BaseProjectionLayer(object):
Returns:
projected_dist, old_dist (from rollouts)
"""
old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps)
old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp)
return self(dist, old_distribution, **kwargs), old_distribution
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):