Ensure bias terms on correct device
This commit is contained in:
parent
6c83406492
commit
446eee5fa1
@ -358,11 +358,13 @@ class StdNet(nn.Module):
|
||||
return self._ensure_positive_func(self.param * cont)
|
||||
elif self.par_strength == Par_Strength.CONT_DIAG:
|
||||
cont = self.net(x)
|
||||
diag_chol = cont + self.bias
|
||||
bias = self.bias.to(cont.device)
|
||||
diag_chol = cont + bias
|
||||
return self._ensure_positive_func(diag_chol)
|
||||
elif self.par_strength == Par_Strength.CONT_FULL:
|
||||
cont = self.net(x)
|
||||
return self._chol_from_flat(cont + self.bias)
|
||||
bias = self.bias.to(device=cont.device)
|
||||
return self._chol_from_flat(cont + bias)
|
||||
|
||||
raise Exception()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user