diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index 0162710..c7b9cb4 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -126,7 +126,7 @@ class BaseProjectionLayer(object): """ return kl_divergence(p, q) - def new_dist_like(orig_p, mean, cov_cholesky): + def new_dist_like(self, orig_p, mean, cov_cholesky): assert isinstance(orig_p, Distribution) p = orig_p.distribution if isinstance(p, th.distributions.Normal): diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index 7c99f59..5bb785c 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -94,7 +94,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): return kl_loss * self.trust_region_coeff - def new_dist_like(orig_p, mean, cov_sqrt): + def new_dist_like(self, orig_p, mean, cov_sqrt): assert isinstance(orig_p, Distribution) p = orig_p.distribution if isinstance(p, th.distributions.Normal):