"""Benchmark interpolation schemes for symmetric positive-definite (SPD) matrices.

For random pairs ``a``, ``b`` of SPD matrices, each scheme produces an
interpolated matrix ``m`` that puts weight ``eta / (1 + eta)`` on ``b``.  The
error of a scheme is ``|dist(a, m) + dist(m, b) - dist(a, b)|`` where ``dist``
is pymanopt's SPD-manifold distance; a point on the geodesic between ``a`` and
``b`` makes the triangle inequality tight (error 0).  Mean errors, divided by
``d`` and expressed in percent, are printed in ascending order, skipping the
schemes listed in ``blacklist``.
"""
import torch as th
import geoopt
import pymanopt
from tqdm import tqdm

# --- experiment configuration ------------------------------------------------
n = 3  # matrix size
spd_alt = geoopt.SymmetricPositiveDefinite()  # only used for projx()
spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1)
so = pymanopt.manifolds.special_orthogonal_group.SpecialOrthogonalGroup(n, k=1)
dist = spd.dist
shape = (n, n)
eta = 1  # weight of b relative to a: m = (a + eta*b) / (1 + eta)
pos = eta / (1 + eta)  # the same weight as a fraction of the geodesic a -> b
s = 3.14159  # scale of the random base matrix
d = 0.01  # relative perturbation size; also used to normalise reported errors
eps = 0.001  # diagonal regulariser added to a
num = 1024  # number of accepted samples
# Schemes listed here are not reported (and, where possible, not computed).
blacklist = []
blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann']
#blacklist += ['linear_eigen', 'sqrt_eigen']
#blacklist += ['scaled_eigen']
#blacklist += ['euclidean_sqrt']


def genRandSPDs(local=True):
    """Return a pair ``(a, b)`` of SPD matrices as float32 torch tensors.

    With ``local=True``, ``b`` is ``a`` plus a random SPD perturbation scaled
    by ``s * d`` (and a diagonal shift); otherwise ``a`` and ``b`` are drawn
    independently.
    """
    if local:
        a = spd.random_point() * s
        # NOTE(review): b is shifted by eta*I while a gets eps*I; eta is the
        # interpolation weight, so this may have been meant to be eps.  Kept
        # as-is because the recorded results below were produced with it.
        return (
            th.Tensor(a) + th.eye(n) * eps,
            th.Tensor(a + spd.random_point() * s * d) + th.eye(n) * eta,
        )
    # Independent draws, returned as tensors like the local branch.  Note that
    # calcErrors rejects pairs whose eigenvectors differ by more than 2*d, so
    # most independent pairs will be redrawn by the caller.
    return th.Tensor(spd.random_point()), th.Tensor(spd.random_point())


akku = 0  # running sum of dist(sqrt-euclidean mid, sqrt-eigen mid) / dist(a, b)


def calcErrors(a, b):
    """Compute every scheme's interpolation error for the pair ``(a, b)``.

    ``a`` and ``b`` are torch tensors (as returned by ``genRandSPDs``).
    Returns ``False`` if the eigenvector bases of ``a`` and ``b`` are too far
    apart (the caller should draw a new pair).  Otherwise returns an 11-tuple
    of absolute errors in the order of the module-level ``names`` list, and
    adds one sample to the global ``akku``.
    """
    global akku

    ### eigen decomposition
    ewa, eva = th.linalg.eigh(a)
    ewb, evb = th.linalg.eigh(b)
    ewa, eva = ewa.real, eva.real
    ewb, evb = ewb.real, evb.real
    if th.norm(eva - evb) > d * 2:
        # EVs flipped (sign/order mismatch); the eigen-based schemes would
        # interpolate mismatched bases, so reject and let the caller retry.
        return False

    ### euclidean (reconstructed from the eigendecomposition for a fair comparison)
    ar = eva @ th.diag(ewa) @ eva.T
    br = evb @ th.diag(ewb) @ evb.T
    emb = (ar + eta * br) / (1 + eta)

    ### euclidean_sqrt: interpolate matrix square roots, then square
    asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
    bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
    mbsqrt = (asqrt + eta * bsqrt) / (1 + eta)
    sqrtmb = mbsqrt @ mbsqrt

    ### cholesky: interpolate Cholesky factors (also built from the eigendecomposition)
    la = th.linalg.cholesky(ar)
    lb = th.linalg.cholesky(br)
    cholmb_chol = (la + eta * lb) / (1 + eta)
    cholmb = cholmb_chol @ cholmb_chol.T

    ### riemann: point at fraction `pos` along the SPD geodesic
    if 'riemann' in blacklist:
        riemannmb = ar  # placeholder; not reported
    else:
        riemannmb = th.Tensor(
            spd.exp(ar.numpy(), pos * (spd.log(ar.numpy(), br.numpy())))
        ).real.float()

    ### eigen-based approximations
    # commutative: interpolate eigenvalues, keep a's eigenbasis
    ewmb, ecvmb = (ewa + eta * ewb) / (1 + eta), eva
    # linear: interpolate eigenbases entry-wise
    elvmb = (eva + eta * evb) / (1 + eta)
    # not closed form, but stable gradients
    elvmb_retr = th.Tensor(so.retraction(eva, elvmb - eva))
    if so.norm(elvmb_retr, elvmb - elvmb_retr) > 1.0:
        elvmb = elvmb_retr
    # linear riemann: interpolate eigenbases along the SO(n) geodesic
    if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist:
        elrvmb = eva  # placeholder; not reported
    else:
        elrvmb = th.Tensor(so.exp(eva, pos * (so.log(eva, evb)))).real.float()
    # sqrt eigenvalues: interpolate sqrt(eigenvalues), then square
    ewsqrtmb_sqrt = (th.sqrt(ewa) + eta * th.sqrt(ewb)) / (1 + eta)
    ewsqrtmb = ewsqrtmb_sqrt ** 2
    # scaled: interpolate eigenvalue-scaled eigenbases, divide by mean eigenvalues
    esvmb = ((eva @ th.diag(ewa) + eta * evb @ th.diag(ewb)) / (1 + eta)) / ewmb

    cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T
    cmb = spd_alt.projx(cmb)
    lmb = elvmb @ th.diag(ewmb) @ elvmb.T
    #lmb = spd_alt.projx(lmb) # THIS
    lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T
    lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T
    #lrmb = spd_alt.projx(lrmb)
    lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T
    smb = esvmb @ th.diag(ewmb) @ esvmb.T
    #smb = spd_alt.projx(smb)
    lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T

    ### checking (pymanopt's dist works on numpy arrays)
    a = a.numpy()  # Sigma_old
    b = b.numpy()  # Sigma
    emb = emb.numpy()  # euclidean line
    cholmb = cholmb.numpy()  # line in chol space
    sqrtmb = sqrtmb.numpy()  # line in sqrt-matrix space
    riemannmb = riemannmb.numpy()  # spd geodesic (theoretical best case)
    cmb = cmb.numpy()  # eigen under commutative assumption
    lmb = lmb.numpy()  # eigen with linear basis interpolation
    lrmb = lrmb.numpy()  # eigen with eigenbasis interpolation along so(n) geodesic
    lsqrtmb = lsqrtmb.numpy()  # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
    lrsqrtmb = lrsqrtmb.numpy()  # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
    lssqrtmb = lssqrtmb.numpy()  # eigen with scaled eigenbasis interpolation and sqrt interpol for EW
    smb = smb.numpy()  # eigen with scaled interpolation

    # ground truth
    tru_damb = dist(a, b)
    # euclid
    euc_damb = dist(a, emb) + dist(emb, b)
    # riemann geodesic
    riemann_damb = dist(a, riemannmb) + dist(riemannmb, b)
    # euclid_sqrt
    sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b)
    # chol
    chol_damb = dist(a, cholmb) + dist(cholmb, b)
    # ew com
    ewc_damb = dist(a, cmb) + dist(cmb, b)
    # ew lin
    if 'linear_eigen' in blacklist:
        ewl_damb = 0
    else:
        ewl_damb = dist(a, lmb) + dist(lmb, b)
    # ew sqrt
    if 'sqrt_eigen' in blacklist:
        ewlsqrt_damb = 0
    else:
        ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b)
    # ew sca sqrt
    ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b)
    # ew lin riemann
    ewlr_damb = dist(a, lrmb) + dist(lrmb, b)
    # ew riemann sqrt
    ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b)
    # ew sca
    ews_damb = dist(a, smb) + dist(smb, b)

    akku += dist(sqrtmb, lsqrtmb) / tru_damb

    # Order must match the module-level `names` list.
    return (
        abs(euc_damb - tru_damb),
        abs(ewc_damb - tru_damb),
        abs(ewl_damb - tru_damb),
        abs(ews_damb - tru_damb),
        abs(chol_damb - tru_damb),
        abs(sqrt_damb - tru_damb),
        abs(ewlr_damb - tru_damb),
        abs(ewlrsqrt_damb - tru_damb),
        abs(ewlsqrt_damb - tru_damb),
        abs(riemann_damb - tru_damb),
        abs(ewlssqrt_damb - tru_damb),
    )


def testSingle(local=True):
    """Draw one random pair and return ``calcErrors`` for it (may be ``False``)."""
    a, b = genRandSPDs(local=local)
    return calcErrors(a, b)


def test(num=1024, local=True):
    """Return the mean of each scheme's error over ``num`` accepted samples.

    Pairs rejected by ``calcErrors`` (it returns ``False``) are redrawn.  The
    result is an 11-tuple in the order of the module-level ``names`` list.
    """
    # One running total per scheme; accumulating as a list keeps the output
    # order tied to calcErrors' return order (no hand-written per-scheme
    # variables that can be mixed up).
    totals = [0.0] * 11
    for _ in tqdm(range(num)):
        res = testSingle(local=local)
        while res is False:
            res = testSingle(local=local)
        totals = [total + err for total, err in zip(totals, res)]
    return tuple(total / num for total in totals)


names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen']

# Mean errors relative to the perturbation size d, in percent.
res = th.Tensor(test(num=num, local=True)) / d * 100
# Loop variable is `name`, not `n`, so the matrix-size global is not clobbered.
for name, r in sorted(zip(names, res), key=lambda pair: pair[1].item()):
    if name not in blacklist:
        print(f'{name}: {float(r):.6f}%')
print('---')
print(str(akku / num * 100) + '%')

#--- recorded results with:
#s = 3.14159
#d = 0.01
#eps = 0.001
#100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s]
#sqrt_eigen: 0.030458%
#scaled_sqrt_eigen: 0.030458%  (NOTE: recorded while test() returned sqrt_eigen's error for this slot)
#euclidean_chol: 0.033712%
#euclidean: 0.119651%
#scaled_eigen: 0.119896%
#linear_eigen: 0.119899%
#commutative_eigen: 0.154270%
#---
#0.012237632813561243%