commit 6f43bfb7192581a4d16ab6138814f7651abe3893 Author: Dominik Roth Date: Mon Nov 28 23:07:21 2022 +0100 initial commit diff --git a/spdesic.py b/spdesic.py new file mode 100644 index 0000000..773bdfa --- /dev/null +++ b/spdesic.py @@ -0,0 +1,238 @@ +import torch as th +import geoopt +import pymanopt +from tqdm import tqdm + +n = 3 + +spd_alt = geoopt.SymmetricPositiveDefinite() +spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1) +so = pymanopt.manifolds.special_orthogonal_group.SpecialOrthogonalGroup(n, k=1) + +dist = spd.dist + +shape = (n,n) +eta = 1 +pos = (eta)/(1+eta) +s = 3.14159 +d = 0.01 +eps = 0.001 + +num = 1024 + + +blacklist = [] +blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann'] +#blacklist += ['linear_eigen', 'sqrt_eigen'] +#blacklist += ['scaled_eigen'] +#blacklist += ['euclidean_sqrt'] + +def genRandSPDs(local=True): + if local: + a = spd.random_point()*s + #a = spd.random(shape) + #return a, a + spd.random(shape)*d + return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(3)*eta + return spd.random(shape), spd.random(shape) + +akku = 0 + +def calcErrors(a, b): + global akku + ### eigen decomp + ewa, eva = th.linalg.eigh(a) + ewb, evb = th.linalg.eigh(b) + + ewa, eva = ewa.real, eva.real + ewb, evb = ewb.real, evb.real + + if th.norm(eva-evb) > d*2: + # EVs flipped; try again... + return False + + ### euclidean approx (also depends on eigendecomp for fair comparison) + ar = eva @ th.diag(ewa) @ eva.T + br = evb @ th.diag(ewb) @ evb.T + + emb = (ar + eta*br)/(1+eta) + + ### euclidean_sqrt + asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T + bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T + + mbsqrt = (asqrt + eta*bsqrt)/(1+eta) + sqrtmb = mbsqrt@mbsqrt + + ### chol approx (also depends on eigendecomp for fair comparison) + la = th.linalg.cholesky(ar) + lb = th.linalg.cholesky(br) + + cholmb_chol = (la + eta*lb)/(1+eta) + cholmb = cholmb_chol@cholmb_chol.T + + ### riemann + if 'riemann' in blacklist: + riemannmb = ar + else: + riemannmb = th.Tensor(spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float() + + ### eigen approx + + # com + ewmb, ecvmb = (ewa + eta*ewb)/(1+eta), eva + + # lin + elvmb = (eva + eta*evb)/(1+eta) + # not closed form, but stable gradients + elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva)) + if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0: + elvmb = elvmb_retr + + # lin_riemann + if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist: + elrvmb = eva + else: + elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float() + + # ew_sqrt + ewsqrtmb_sqrt = (th.sqrt(ewa) + eta*th.sqrt(ewb))/(1+eta) + ewsqrtmb = ewsqrtmb_sqrt**2 + + # ew scaling + esvmb = ((eva@th.diag(ewa) + eta*evb@th.diag(ewb))/(1+eta))/ewmb + + cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T + cmb = spd_alt.projx(cmb) + + lmb = elvmb @ th.diag(ewmb) @ elvmb.T + #lmb = spd_alt.projx(lmb) + + # THIS + lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T + + lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T + #lrmb = spd_alt.projx(lrmb) + + lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T + + smb = esvmb @ th.diag(ewmb) @ esvmb.T + #smb = spd_alt.projx(smb) + + lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T + + ### checking + if True: + a = a.numpy() # Sigma_old + b = b.numpy() # Sigma + + emb = emb.numpy() # euclidean line + cholmb = cholmb.numpy() # line in chol space + sqrtmb = sqrtmb.numpy() # line in spq-matrix space + riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case) + cmb = cmb.numpy() # eigen under commutative assumption + lmb = lmb.numpy() # eigen with linear basis interpolation + lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic + lsqrtmb = lsqrtmb.numpy() # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol) + lrsqrtmb = lrsqrtmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW + lssqrtmb = lssqrtmb.numpy() # eigen with scaled eigenbasis interpolation and sqrt interpol for EW + smb = smb.numpy() # eigen with scaled interpolation + + # ground truth + tru_damb = dist(a, b) + + # euclid + euc_damb = dist(a, emb) + dist(emb, b) + + # euclid + riemann_damb = dist(a, riemannmb) + dist(riemannmb, b) + + # euclid_sqrt + sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b) + + # chol + chol_damb = dist(a, cholmb) + dist(cholmb, b) + + # ew com + ewc_damb = dist(a, cmb) + dist(cmb, b) + + # ew lin + if 'linear_eigen' in blacklist: + ewl_damb = 0 + else: + ewl_damb = dist(a, lmb) + dist(lmb, b) + + # ew sqrt + if 'sqrt_eigen' in blacklist: + ewlsqrt_damb = 0 + else: + ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b) + + # ew sca sqrt + ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b) + + # ew lin riemann + ewlr_damb = dist(a, lrmb) + dist(lrmb, b) + + # ew riemann sqrt + ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b) + + # ew sca + ews_damb = dist(a, smb) + dist(smb, b) + + akku += dist(sqrtmb, lsqrtmb)/tru_damb + + return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb) + #except: + # print('num issue') + # return 0, 0, 0, 0 + +def testSingle(local=True): + a, b = genRandSPDs(local=local) + return calcErrors(a, b) + +def test(num=1024, local=True): + euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + for i in tqdm(range(num)): + res = False + while res == False: + res = testSingle(local=local) + euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err = res + euc_errs += euc_err + ewc_errs += ewc_err + ewl_errs += ewl_err + ews_errs += ews_err + chol_errs += chol_err + sqrt_errs += sqrt_err + ewlr_errs += ewlr_err + ewlrsqrt_errs += ewlrsqrt_err + ewlsqrt_errs += ewlsqrt_err + rie_errs += rie_err + ewlssqrt_errs += ewlssqrt_err + return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num + +names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen'] + +res = th.Tensor(test(num=num, local=True))/d*100 + +for n,r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False): + if not n in blacklist: + print(n+': '+'%.6f' % r+'%') + +print('---') +print(str(akku/num*100) + '%') + + +#--- +#s = 3.14159 +#d = 0.01 +#eps = 0.001 +#100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s] +#sqrt_eigen: 0.030458% +#scaled_sqrt_eigen: 0.030458% +#euclidean_chol: 0.033712% +#euclidean: 0.119651% +#scaled_eigen: 0.119896% +#linear_eigen: 0.119899% +#commutative_eigen: 0.154270% +#--- +#0.012237632813561243%