CHanges to work with new version of metastable baselines 2
This commit is contained in:
parent
a5309e0fb8
commit
8a438e275f
@ -13,7 +13,7 @@ class BaseProjectionLayer(object):
|
||||
mean_bound: float = 0.03,
|
||||
cov_bound: float = 1e-3,
|
||||
trust_region_coeff: float = 1.0,
|
||||
scale_prec: bool = True,
|
||||
scale_prec: bool = False,
|
||||
do_entropy_proj: bool = False,
|
||||
entropy_eq: bool = False,
|
||||
entropy_first: bool = False,
|
||||
@ -26,6 +26,10 @@ class BaseProjectionLayer(object):
|
||||
self.scale_prec = scale_prec
|
||||
self.mean_eq = False
|
||||
|
||||
assert not entropy_eq, 'Sorry pal; thats actually not implemented yet.'
|
||||
assert not entropy_first, 'Sorry pal; thats actually not implemented yet.'
|
||||
assert not do_entropy_proj, 'Sorry pal; thats actually not implemented yet.'
|
||||
|
||||
self.entropy_first = entropy_first
|
||||
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
|
||||
|
||||
@ -63,6 +67,7 @@ class BaseProjectionLayer(object):
|
||||
|
||||
####################################################################################################################
|
||||
# entropy projection in the end
|
||||
|
||||
if not self.do_entropy_proj or self.entropy_first:
|
||||
return new_p
|
||||
|
||||
@ -81,7 +86,7 @@ class BaseProjectionLayer(object):
|
||||
Returns:
|
||||
projected_dist, old_dist (from rollouts)
|
||||
"""
|
||||
old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps)
|
||||
old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp)
|
||||
return self(dist, old_distribution, **kwargs), old_distribution
|
||||
|
||||
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user