Ensure bias terms on correct device
This commit is contained in:
parent
6c83406492
commit
446eee5fa1
@ -358,11 +358,13 @@ class StdNet(nn.Module):
|
|||||||
return self._ensure_positive_func(self.param * cont)
|
return self._ensure_positive_func(self.param * cont)
|
||||||
elif self.par_strength == Par_Strength.CONT_DIAG:
|
elif self.par_strength == Par_Strength.CONT_DIAG:
|
||||||
cont = self.net(x)
|
cont = self.net(x)
|
||||||
diag_chol = cont + self.bias
|
bias = self.bias.to(cont.device)
|
||||||
|
diag_chol = cont + bias
|
||||||
return self._ensure_positive_func(diag_chol)
|
return self._ensure_positive_func(diag_chol)
|
||||||
elif self.par_strength == Par_Strength.CONT_FULL:
|
elif self.par_strength == Par_Strength.CONT_FULL:
|
||||||
cont = self.net(x)
|
cont = self.net(x)
|
||||||
return self._chol_from_flat(cont + self.bias)
|
bias = self.bias.to(device=cont.device)
|
||||||
|
return self._chol_from_flat(cont + bias)
|
||||||
|
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user