Ensure bias terms on correct device

This commit is contained in:
Dominik Moritz Roth 2024-04-03 18:25:51 +02:00
parent 6c83406492
commit 446eee5fa1

View File

@ -358,11 +358,13 @@ class StdNet(nn.Module):
return self._ensure_positive_func(self.param * cont) return self._ensure_positive_func(self.param * cont)
elif self.par_strength == Par_Strength.CONT_DIAG: elif self.par_strength == Par_Strength.CONT_DIAG:
cont = self.net(x) cont = self.net(x)
diag_chol = cont + self.bias bias = self.bias.to(cont.device)
diag_chol = cont + bias
return self._ensure_positive_func(diag_chol) return self._ensure_positive_func(diag_chol)
elif self.par_strength == Par_Strength.CONT_FULL: elif self.par_strength == Par_Strength.CONT_FULL:
cont = self.net(x) cont = self.net(x)
return self._chol_from_flat(cont + self.bias) bias = self.bias.to(device=cont.device)
return self._chol_from_flat(cont + bias)
raise Exception() raise Exception()