Allow checking whether a dist is contextual
This commit is contained in:
parent
a8b9c63965
commit
ab1b269af9
@ -52,6 +52,11 @@ def has_diag_cov(p, numerical_check=True):
|
|||||||
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
|
||||||
|
|
||||||
|
|
||||||
|
def is_contextual(p):
|
||||||
|
# TODO: Implement for UniveralGaussianDist
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
|
||||||
if check_diag and not has_diag_cov(p):
|
if check_diag and not has_diag_cov(p):
|
||||||
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
|
||||||
|
Loading…
Reference in New Issue
Block a user