commit message
This commit is contained in:
parent
23236f18a4
commit
73a0271452
102
spdesic.py
102
spdesic.py
@ -2,8 +2,9 @@ import torch as th
|
|||||||
import geoopt
|
import geoopt
|
||||||
import pymanopt
|
import pymanopt
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import math
|
||||||
|
|
||||||
n = 3
|
n = 4
|
||||||
|
|
||||||
spd_alt = geoopt.SymmetricPositiveDefinite()
|
spd_alt = geoopt.SymmetricPositiveDefinite()
|
||||||
spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1)
|
spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1)
|
||||||
@ -15,31 +16,40 @@ shape = (n,n)
|
|||||||
eta = 1
|
eta = 1
|
||||||
pos = (eta)/(1+eta)
|
pos = (eta)/(1+eta)
|
||||||
s = 3.14159
|
s = 3.14159
|
||||||
d = 0.01
|
d = 0.1
|
||||||
eps = 0.001
|
eps = 0.001
|
||||||
|
|
||||||
num = 1024
|
num = 1024
|
||||||
|
|
||||||
|
# ignore these...
|
||||||
|
fev = 1
|
||||||
|
few = 1
|
||||||
|
|
||||||
blacklist = []
|
blacklist = []
|
||||||
blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann']
|
blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann']
|
||||||
#blacklist += ['linear_eigen', 'sqrt_eigen']
|
#blacklist += ['linear_eigen', 'sqrt_eigen']
|
||||||
#blacklist += ['scaled_eigen']
|
blacklist += ['scaled_eigen']
|
||||||
#blacklist += ['euclidean_sqrt']
|
#blacklist += ['euclidean_prec']
|
||||||
|
#blacklist += ['euclidean_sqrt', 'commutative_eigen']
|
||||||
|
|
||||||
|
|
||||||
def genRandSPDs(local=True):
|
def genRandSPDs(local=True):
|
||||||
if local:
|
if local:
|
||||||
a = spd.random_point()*s
|
a = spd.random_point()*s
|
||||||
#a = spd.random(shape)
|
#a = spd.random(shape)
|
||||||
# return a, a + spd.random(shape)*d
|
# return a, a + spd.random(shape)*d
|
||||||
return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(3)*eta
|
return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(n)*eta
|
||||||
return spd.random(shape), spd.random(shape)
|
return spd.random(shape), spd.random(shape)
|
||||||
|
|
||||||
|
|
||||||
akku = 0
|
akku = 0
|
||||||
|
|
||||||
|
|
||||||
def calcErrors(a, b):
|
def calcErrors(a, b):
|
||||||
global akku
|
global akku
|
||||||
### eigen decomp
|
etaV, etaW = eta * fev, eta * few
|
||||||
|
|
||||||
|
# eigen decomp
|
||||||
ewa, eva = th.linalg.eigh(a)
|
ewa, eva = th.linalg.eigh(a)
|
||||||
ewb, evb = th.linalg.eigh(b)
|
ewb, evb = th.linalg.eigh(b)
|
||||||
|
|
||||||
@ -50,39 +60,49 @@ def calcErrors(a, b):
|
|||||||
# EVs flipped; try again...
|
# EVs flipped; try again...
|
||||||
return False
|
return False
|
||||||
|
|
||||||
### euclidean approx (also depends on eigendecomp for fair comparison)
|
# euclidean approx (also depends on eigendecomp for fair comparison)
|
||||||
ar = eva @ th.diag(ewa) @ eva.T
|
ar = eva @ th.diag(ewa) @ eva.T
|
||||||
br = evb @ th.diag(ewb) @ evb.T
|
br = evb @ th.diag(ewb) @ evb.T
|
||||||
|
|
||||||
emb = (ar + eta*br)/(1+eta)
|
emb = (ar + eta*br)/(1+eta)
|
||||||
|
|
||||||
### euclidean_sqrt
|
# euclidean_sqrt
|
||||||
asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
|
asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
|
||||||
bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
|
bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
|
||||||
|
|
||||||
mbsqrt = (asqrt + eta*bsqrt)/(1+eta)
|
mbsqrt = (asqrt + eta*bsqrt)/(1+eta)
|
||||||
sqrtmb = mbsqrt@mbsqrt
|
sqrtmb = mbsqrt@mbsqrt
|
||||||
|
|
||||||
### chol approx (also depends on eigendecomp for fair comparison)
|
# euclidean_prec
|
||||||
|
ainv = eva @ th.diag(1/ewa) @ eva.T
|
||||||
|
binv = evb @ th.diag(1/ewb) @ evb.T
|
||||||
|
|
||||||
|
if 'euclidean_prec' in blacklist:
|
||||||
|
precmb = ar
|
||||||
|
else:
|
||||||
|
precmb = th.inverse((ainv + eta*binv)/(1+eta))
|
||||||
|
|
||||||
|
# chol approx (also depends on eigendecomp for fair comparison)
|
||||||
la = th.linalg.cholesky(ar)
|
la = th.linalg.cholesky(ar)
|
||||||
lb = th.linalg.cholesky(br)
|
lb = th.linalg.cholesky(br)
|
||||||
|
|
||||||
cholmb_chol = (la + eta*lb)/(1+eta)
|
cholmb_chol = (la + eta*lb)/(1+eta)
|
||||||
cholmb = cholmb_chol@cholmb_chol.T
|
cholmb = cholmb_chol@cholmb_chol.T
|
||||||
|
|
||||||
### riemann
|
# riemann
|
||||||
if 'riemann' in blacklist:
|
if 'riemann' in blacklist:
|
||||||
riemannmb = ar
|
riemannmb = ar
|
||||||
else:
|
else:
|
||||||
riemannmb = th.Tensor(spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float()
|
riemannmb = th.Tensor(
|
||||||
|
spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float()
|
||||||
|
|
||||||
### eigen approx
|
# eigen approx
|
||||||
|
|
||||||
# com
|
# com
|
||||||
ewmb, ecvmb = (ewa + eta*ewb)/(1+eta), eva
|
ewmb, ecvmb = (ewa + etaW*ewb)/(1+etaW), eva
|
||||||
|
|
||||||
# lin
|
# lin
|
||||||
elvmb = (eva + eta*evb)/(1+eta)
|
elvmb = (eva + etaV*evb)/(1+etaV)
|
||||||
# not closed form, but stable gradients
|
# not closed form, but stable gradients
|
||||||
elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva))
|
elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva))
|
||||||
if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0:
|
if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0:
|
||||||
@ -95,11 +115,14 @@ def calcErrors(a, b):
|
|||||||
elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float()
|
elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float()
|
||||||
|
|
||||||
# ew_sqrt
|
# ew_sqrt
|
||||||
ewsqrtmb_sqrt = (th.sqrt(ewa) + eta*th.sqrt(ewb))/(1+eta)
|
ewsqrtmb_sqrt = (th.sqrt(ewa) + etaW*th.sqrt(ewb))/(1+etaW)
|
||||||
ewsqrtmb = ewsqrtmb_sqrt**2
|
ewsqrtmb = ewsqrtmb_sqrt**2
|
||||||
|
|
||||||
|
# ew_inv
|
||||||
|
ewinvmb = 1/((1/ewa + etaW*(1/ewb))/(1+etaW))
|
||||||
|
|
||||||
# ew scaling
|
# ew scaling
|
||||||
esvmb = ((eva@th.diag(ewa) + eta*evb@th.diag(ewb))/(1+eta))/ewmb
|
esvmb = ((eva@th.diag(ewa) + etaV*evb@th.diag(ewb))/(1+etaV))/ewmb
|
||||||
|
|
||||||
cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T
|
cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T
|
||||||
cmb = spd_alt.projx(cmb)
|
cmb = spd_alt.projx(cmb)
|
||||||
@ -110,6 +133,8 @@ def calcErrors(a, b):
|
|||||||
# THIS
|
# THIS
|
||||||
lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T
|
lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T
|
||||||
|
|
||||||
|
linvmb = elvmb @ th.diag(ewinvmb) @ elvmb.T
|
||||||
|
|
||||||
lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T
|
lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T
|
||||||
#lrmb = spd_alt.projx(lrmb)
|
#lrmb = spd_alt.projx(lrmb)
|
||||||
|
|
||||||
@ -120,7 +145,7 @@ def calcErrors(a, b):
|
|||||||
|
|
||||||
lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T
|
lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T
|
||||||
|
|
||||||
### checking
|
# checking
|
||||||
if True:
|
if True:
|
||||||
a = a.numpy() # Sigma_old
|
a = a.numpy() # Sigma_old
|
||||||
b = b.numpy() # Sigma
|
b = b.numpy() # Sigma
|
||||||
@ -132,10 +157,15 @@ def calcErrors(a, b):
|
|||||||
cmb = cmb.numpy() # eigen under commutative assumption
|
cmb = cmb.numpy() # eigen under commutative assumption
|
||||||
lmb = lmb.numpy() # eigen with linear basis interpolation
|
lmb = lmb.numpy() # eigen with linear basis interpolation
|
||||||
lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic
|
lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic
|
||||||
lsqrtmb = lsqrtmb.numpy() # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
|
# eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
|
||||||
lrsqrtmb = lrsqrtmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
|
lsqrtmb = lsqrtmb.numpy()
|
||||||
lssqrtmb = lssqrtmb.numpy() # eigen with scaled eigenbasis interpolation and sqrt interpol for EW
|
# eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
|
||||||
|
lrsqrtmb = lrsqrtmb.numpy()
|
||||||
|
# eigen with scaled eigenbasis interpolation and sqrt interpol for EW
|
||||||
|
lssqrtmb = lssqrtmb.numpy()
|
||||||
smb = smb.numpy() # eigen with scaled interpolation
|
smb = smb.numpy() # eigen with scaled interpolation
|
||||||
|
precmb = precmb.numpy()
|
||||||
|
linvmb = linvmb.numpy()
|
||||||
|
|
||||||
# ground truth
|
# ground truth
|
||||||
tru_damb = dist(a, b)
|
tru_damb = dist(a, b)
|
||||||
@ -149,6 +179,9 @@ def calcErrors(a, b):
|
|||||||
# euclid_sqrt
|
# euclid_sqrt
|
||||||
sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b)
|
sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b)
|
||||||
|
|
||||||
|
# prec
|
||||||
|
prec_damb = dist(a, precmb) + dist(precmb, b)
|
||||||
|
|
||||||
# chol
|
# chol
|
||||||
chol_damb = dist(a, cholmb) + dist(cholmb, b)
|
chol_damb = dist(a, cholmb) + dist(cholmb, b)
|
||||||
|
|
||||||
@ -161,6 +194,9 @@ def calcErrors(a, b):
|
|||||||
else:
|
else:
|
||||||
ewl_damb = dist(a, lmb) + dist(lmb, b)
|
ewl_damb = dist(a, lmb) + dist(lmb, b)
|
||||||
|
|
||||||
|
# ew inv
|
||||||
|
ewinv_damb = dist(a, linvmb) + dist(linvmb, b)
|
||||||
|
|
||||||
# ew sqrt
|
# ew sqrt
|
||||||
if 'sqrt_eigen' in blacklist:
|
if 'sqrt_eigen' in blacklist:
|
||||||
ewlsqrt_damb = 0
|
ewlsqrt_damb = 0
|
||||||
@ -179,24 +215,26 @@ def calcErrors(a, b):
|
|||||||
# ew sca
|
# ew sca
|
||||||
ews_damb = dist(a, smb) + dist(smb, b)
|
ews_damb = dist(a, smb) + dist(smb, b)
|
||||||
|
|
||||||
akku += dist(sqrtmb, lsqrtmb)/tru_damb
|
akku += tru_damb
|
||||||
|
|
||||||
return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb)
|
return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb), abs(prec_damb-tru_damb), abs(ewinv_damb-tru_damb)
|
||||||
# except:
|
# except:
|
||||||
# print('num issue')
|
# print('num issue')
|
||||||
# return 0, 0, 0, 0
|
# return 0, 0, 0, 0
|
||||||
|
|
||||||
|
|
||||||
def testSingle(local=True):
|
def testSingle(local=True):
|
||||||
a, b = genRandSPDs(local=local)
|
a, b = genRandSPDs(local=local)
|
||||||
return calcErrors(a, b)
|
return calcErrors(a, b)
|
||||||
|
|
||||||
|
|
||||||
def test(num=1024, local=True):
|
def test(num=1024, local=True):
|
||||||
euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
|
euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs, prec_errs, ewinv_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
|
||||||
for i in tqdm(range(num)):
|
for i in tqdm(range(num)):
|
||||||
res = False
|
res = False
|
||||||
while res == False:
|
while res == False:
|
||||||
res = testSingle(local=local)
|
res = testSingle(local=local)
|
||||||
euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err = res
|
euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err, prec_err, ewinv_err = res
|
||||||
euc_errs += euc_err
|
euc_errs += euc_err
|
||||||
ewc_errs += ewc_err
|
ewc_errs += ewc_err
|
||||||
ewl_errs += ewl_err
|
ewl_errs += ewl_err
|
||||||
@ -208,19 +246,25 @@ def test(num=1024, local=True):
|
|||||||
ewlsqrt_errs += ewlsqrt_err
|
ewlsqrt_errs += ewlsqrt_err
|
||||||
rie_errs += rie_err
|
rie_errs += rie_err
|
||||||
ewlssqrt_errs += ewlssqrt_err
|
ewlssqrt_errs += ewlssqrt_err
|
||||||
return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num
|
prec_errs += prec_err
|
||||||
|
ewinv_errs += ewinv_err
|
||||||
|
return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num, prec_errs/num, ewinv_errs/num
|
||||||
|
|
||||||
names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen']
|
|
||||||
|
|
||||||
res = th.Tensor(test(num=num, local=True))/d*100
|
names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt',
|
||||||
|
'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen', 'euclidean_prec', 'inv_eigen']
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
res = th.Tensor(test(num=num, local=True))/(akku/num)*100
|
||||||
|
|
||||||
for n, r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False):
|
for n, r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False):
|
||||||
if not n in blacklist:
|
if not n in blacklist:
|
||||||
print(n+': '+'%.6f' % r+'%')
|
print(n+': '+'%.6f' % r+'%')
|
||||||
|
|
||||||
print('---')
|
|
||||||
print(str(akku/num*100) + '%')
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
# ---
|
# ---
|
||||||
#s = 3.14159
|
#s = 3.14159
|
||||||
|
Loading…
Reference in New Issue
Block a user