diff --git a/priorConditionedAnnealing/pca.py b/priorConditionedAnnealing/pca.py index 3589987..ead2f65 100644 --- a/priorConditionedAnnealing/pca.py +++ b/priorConditionedAnnealing/pca.py @@ -358,11 +358,13 @@ class StdNet(nn.Module): return self._ensure_positive_func(self.param * cont) elif self.par_strength == Par_Strength.CONT_DIAG: cont = self.net(x) - diag_chol = cont + self.bias + bias = self.bias.to(cont.device) + diag_chol = cont + bias return self._ensure_positive_func(diag_chol) elif self.par_strength == Par_Strength.CONT_FULL: cont = self.net(x) - return self._chol_from_flat(cont + self.bias) + bias = self.bias.to(device=cont.device) + return self._chol_from_flat(cont + bias) raise Exception()