Changes to work with new version of metastable baselines 2
This commit is contained in:
parent
a5309e0fb8
commit
8a438e275f
@ -13,7 +13,7 @@ class BaseProjectionLayer(object):
|
|||||||
mean_bound: float = 0.03,
|
mean_bound: float = 0.03,
|
||||||
cov_bound: float = 1e-3,
|
cov_bound: float = 1e-3,
|
||||||
trust_region_coeff: float = 1.0,
|
trust_region_coeff: float = 1.0,
|
||||||
scale_prec: bool = True,
|
scale_prec: bool = False,
|
||||||
do_entropy_proj: bool = False,
|
do_entropy_proj: bool = False,
|
||||||
entropy_eq: bool = False,
|
entropy_eq: bool = False,
|
||||||
entropy_first: bool = False,
|
entropy_first: bool = False,
|
||||||
@ -26,6 +26,10 @@ class BaseProjectionLayer(object):
|
|||||||
self.scale_prec = scale_prec
|
self.scale_prec = scale_prec
|
||||||
self.mean_eq = False
|
self.mean_eq = False
|
||||||
|
|
||||||
|
assert not entropy_eq, 'Sorry pal; thats actually not implemented yet.'
|
||||||
|
assert not entropy_first, 'Sorry pal; thats actually not implemented yet.'
|
||||||
|
assert not do_entropy_proj, 'Sorry pal; thats actually not implemented yet.'
|
||||||
|
|
||||||
self.entropy_first = entropy_first
|
self.entropy_first = entropy_first
|
||||||
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
|
self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection
|
||||||
|
|
||||||
@ -63,6 +67,7 @@ class BaseProjectionLayer(object):
|
|||||||
|
|
||||||
####################################################################################################################
|
####################################################################################################################
|
||||||
# entropy projection in the end
|
# entropy projection in the end
|
||||||
|
|
||||||
if not self.do_entropy_proj or self.entropy_first:
|
if not self.do_entropy_proj or self.entropy_first:
|
||||||
return new_p
|
return new_p
|
||||||
|
|
||||||
@ -81,7 +86,7 @@ class BaseProjectionLayer(object):
|
|||||||
Returns:
|
Returns:
|
||||||
projected_dist, old_dist (from rollouts)
|
projected_dist, old_dist (from rollouts)
|
||||||
"""
|
"""
|
||||||
old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps)
|
old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp)
|
||||||
return self(dist, old_distribution, **kwargs), old_distribution
|
return self(dist, old_distribution, **kwargs), old_distribution
|
||||||
|
|
||||||
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user