From 446eee5fa17e304bf5d82d82f2e1d29d5a2ca613 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 3 Apr 2024 18:25:51 +0200 Subject: [PATCH] Ensure bias terms on correct device --- priorConditionedAnnealing/pca.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index 3589987..ead2f65 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -358,11 +358,13 @@ class StdNet(nn.Module): return self._ensure_positive_func(self.param * cont) elif self.par_strength == Par_Strength.CONT_DIAG: cont = self.net(x) - diag_chol = cont + self.bias + bias = self.bias.to(cont.device) + diag_chol = cont + bias return self._ensure_positive_func(diag_chol) elif self.par_strength == Par_Strength.CONT_FULL: cont = self.net(x) - return self._chol_from_flat(cont + self.bias) + bias = self.bias.to(device=cont.device) + return self._chol_from_flat(cont + bias) raise Exception()