From 7ad4858e8ac142a9e21612e2baebeb4bbe83fff8 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 26 Jan 2024 12:02:53 +0100 Subject: [PATCH] Bugfix: new_dist_like missing self in args --- metastable_projections/projections/base_projection_layer.py | 2 +- metastable_projections/projections/w2_projection_layer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/metastable_projections/projections/base_projection_layer.py b/metastable_projections/projections/base_projection_layer.py index 0162710..c7b9cb4 100644 --- a/metastable_projections/projections/base_projection_layer.py +++ b/metastable_projections/projections/base_projection_layer.py @@ -126,7 +126,7 @@ class BaseProjectionLayer(object): """ return kl_divergence(p, q) - def new_dist_like(orig_p, mean, cov_cholesky): + def new_dist_like(self, orig_p, mean, cov_cholesky): assert isinstance(orig_p, Distribution) p = orig_p.distribution if isinstance(p, th.distributions.Normal): diff --git a/metastable_projections/projections/w2_projection_layer.py b/metastable_projections/projections/w2_projection_layer.py index 7c99f59..5bb785c 100644 --- a/metastable_projections/projections/w2_projection_layer.py +++ b/metastable_projections/projections/w2_projection_layer.py @@ -94,7 +94,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): return kl_loss * self.trust_region_coeff - def new_dist_like(orig_p, mean, cov_sqrt): + def new_dist_like(self, orig_p, mean, cov_sqrt): assert isinstance(orig_p, Distribution) p = orig_p.distribution if isinstance(p, th.distributions.Normal):