Bugfix: new_dist_like missing self in args

This commit is contained in:
Dominik Moritz Roth 2024-01-26 12:02:53 +01:00
parent 8a438e275f
commit 7ad4858e8a
2 changed files with 2 additions and 2 deletions

View File

@ -126,7 +126,7 @@ class BaseProjectionLayer(object):
""" """
return kl_divergence(p, q) return kl_divergence(p, q)
def new_dist_like(orig_p, mean, cov_cholesky): def new_dist_like(self, orig_p, mean, cov_cholesky):
assert isinstance(orig_p, Distribution) assert isinstance(orig_p, Distribution)
p = orig_p.distribution p = orig_p.distribution
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):

View File

@ -94,7 +94,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer):
return kl_loss * self.trust_region_coeff return kl_loss * self.trust_region_coeff
def new_dist_like(orig_p, mean, cov_sqrt): def new_dist_like(self, orig_p, mean, cov_sqrt):
assert isinstance(orig_p, Distribution) assert isinstance(orig_p, Distribution)
p = orig_p.distribution p = orig_p.distribution
if isinstance(p, th.distributions.Normal): if isinstance(p, th.distributions.Normal):