Bugfix: new_dist_like missing self in args
This commit is contained in:
parent
8a438e275f
commit
7ad4858e8a
@ -126,7 +126,7 @@ class BaseProjectionLayer(object):
|
|||||||
"""
|
"""
|
||||||
return kl_divergence(p, q)
|
return kl_divergence(p, q)
|
||||||
|
|
||||||
def new_dist_like(orig_p, mean, cov_cholesky):
|
def new_dist_like(self, orig_p, mean, cov_cholesky):
|
||||||
assert isinstance(orig_p, Distribution)
|
assert isinstance(orig_p, Distribution)
|
||||||
p = orig_p.distribution
|
p = orig_p.distribution
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
|
@ -94,7 +94,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
|
|||||||
|
|
||||||
return kl_loss * self.trust_region_coeff
|
return kl_loss * self.trust_region_coeff
|
||||||
|
|
||||||
def new_dist_like(orig_p, mean, cov_sqrt):
|
def new_dist_like(self, orig_p, mean, cov_sqrt):
|
||||||
assert isinstance(orig_p, Distribution)
|
assert isinstance(orig_p, Distribution)
|
||||||
p = orig_p.distribution
|
p = orig_p.distribution
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
|
Loading…
Reference in New Issue
Block a user