diff --git a/spdesic.py b/spdesic.py index 773bdfa..ce1a6ea 100644 --- a/spdesic.py +++ b/spdesic.py @@ -2,8 +2,9 @@ import torch as th import geoopt import pymanopt from tqdm import tqdm +import math -n = 3 +n = 4 spd_alt = geoopt.SymmetricPositiveDefinite() spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1) @@ -11,228 +12,271 @@ so = pymanopt.manifolds.special_orthogonal_group.SpecialOrthogonalGroup(n, k=1) dist = spd.dist -shape = (n,n) +shape = (n, n) eta = 1 pos = (eta)/(1+eta) s = 3.14159 -d = 0.01 +d = 0.1 eps = 0.001 num = 1024 +# ignore these... +fev = 1 +few = 1 blacklist = [] blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann'] #blacklist += ['linear_eigen', 'sqrt_eigen'] -#blacklist += ['scaled_eigen'] -#blacklist += ['euclidean_sqrt'] +blacklist += ['scaled_eigen'] +#blacklist += ['euclidean_prec'] +#blacklist += ['euclidean_sqrt', 'commutative_eigen'] + def genRandSPDs(local=True): - if local: - a = spd.random_point()*s - #a = spd.random(shape) - #return a, a + spd.random(shape)*d - return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(3)*eta - return spd.random(shape), spd.random(shape) + if local: + a = spd.random_point()*s + #a = spd.random(shape) + # return a, a + spd.random(shape)*d + return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(n)*eta + return spd.random(shape), spd.random(shape) + akku = 0 + def calcErrors(a, b): - global akku - ### eigen decomp - ewa, eva = th.linalg.eigh(a) - ewb, evb = th.linalg.eigh(b) + global akku + etaV, etaW = eta * fev, eta * few - ewa, eva = ewa.real, eva.real - ewb, evb = ewb.real, evb.real + # eigen decomp + ewa, eva = th.linalg.eigh(a) + ewb, evb = th.linalg.eigh(b) - if th.norm(eva-evb) > d*2: - # EVs flipped; try again... - return False + ewa, eva = ewa.real, eva.real + ewb, evb = ewb.real, evb.real - ### euclidean approx (also depends on eigendecomp for fair comparison) - ar = eva @ th.diag(ewa) @ eva.T - br = evb @ th.diag(ewb) @ evb.T - - emb = (ar + eta*br)/(1+eta) + if th.norm(eva-evb) > d*2: + # EVs flipped; try again... + return False - ### euclidean_sqrt - asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T - bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T + # euclidean approx (also depends on eigendecomp for fair comparison) + ar = eva @ th.diag(ewa) @ eva.T + br = evb @ th.diag(ewb) @ evb.T - mbsqrt = (asqrt + eta*bsqrt)/(1+eta) - sqrtmb = mbsqrt@mbsqrt + emb = (ar + eta*br)/(1+eta) - ### chol approx (also depends on eigendecomp for fair comparison) - la = th.linalg.cholesky(ar) - lb = th.linalg.cholesky(br) - - cholmb_chol = (la + eta*lb)/(1+eta) - cholmb = cholmb_chol@cholmb_chol.T + # euclidean_sqrt + asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T + bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T - ### riemann - if 'riemann' in blacklist: - riemannmb = ar - else: - riemannmb = th.Tensor(spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float() + mbsqrt = (asqrt + eta*bsqrt)/(1+eta) + sqrtmb = mbsqrt@mbsqrt - ### eigen approx - - # com - ewmb, ecvmb = (ewa + eta*ewb)/(1+eta), eva + # euclidean_prec + ainv = eva @ th.diag(1/ewa) @ eva.T + binv = evb @ th.diag(1/ewb) @ evb.T - # lin - elvmb = (eva + eta*evb)/(1+eta) - # not closed form, but stable gradients - elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva)) - if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0: - elvmb = elvmb_retr + if 'euclidean_prec' in blacklist: + precmb = ar + else: + precmb = th.inverse((ainv + eta*binv)/(1+eta)) - # lin_riemann - if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist: - elrvmb = eva - else: - elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float() - - # ew_sqrt - ewsqrtmb_sqrt = (th.sqrt(ewa) + eta*th.sqrt(ewb))/(1+eta) - ewsqrtmb = ewsqrtmb_sqrt**2 + # chol approx (also depends on eigendecomp for fair comparison) + la = th.linalg.cholesky(ar) + lb = th.linalg.cholesky(br) - # ew scaling - esvmb = ((eva@th.diag(ewa) + eta*evb@th.diag(ewb))/(1+eta))/ewmb + cholmb_chol = (la + eta*lb)/(1+eta) + cholmb = cholmb_chol@cholmb_chol.T - cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T - cmb = spd_alt.projx(cmb) + # riemann + if 'riemann' in blacklist: + riemannmb = ar + else: + riemannmb = th.Tensor( + spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float() - lmb = elvmb @ th.diag(ewmb) @ elvmb.T - #lmb = spd_alt.projx(lmb) - - # THIS - lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T - - lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T - #lrmb = spd_alt.projx(lrmb) + # eigen approx - lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T + # com + ewmb, ecvmb = (ewa + etaW*ewb)/(1+etaW), eva - smb = esvmb @ th.diag(ewmb) @ esvmb.T - #smb = spd_alt.projx(smb) + # lin + elvmb = (eva + etaV*evb)/(1+etaV) + # not closed form, but stable gradients + elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva)) + if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0: + elvmb = elvmb_retr - lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T + # lin_riemann + if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist: + elrvmb = eva + else: + elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float() - ### checking - if True: - a = a.numpy() # Sigma_old - b = b.numpy() # Sigma + # ew_sqrt + ewsqrtmb_sqrt = (th.sqrt(ewa) + etaW*th.sqrt(ewb))/(1+etaW) + ewsqrtmb = ewsqrtmb_sqrt**2 - emb = emb.numpy() # euclidean line - cholmb = cholmb.numpy() # line in chol space - sqrtmb = sqrtmb.numpy() # line in spq-matrix space - riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case) - cmb = cmb.numpy() # eigen under commutative assumption - lmb = lmb.numpy() # eigen with linear basis interpolation - lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic - lsqrtmb = lsqrtmb.numpy() # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol) - lrsqrtmb = lrsqrtmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW - lssqrtmb = lssqrtmb.numpy() # eigen with scaled eigenbasis interpolation and sqrt interpol for EW - smb = smb.numpy() # eigen with scaled interpolation + # ew_inv + ewinvmb = 1/((1/ewa + etaW*(1/ewb))/(1+etaW)) - # ground truth - tru_damb = dist(a, b) - - # euclid - euc_damb = dist(a, emb) + dist(emb, b) + # ew scaling + esvmb = ((eva@th.diag(ewa) + etaV*evb@th.diag(ewb))/(1+etaV))/ewmb - # euclid - riemann_damb = dist(a, riemannmb) + dist(riemannmb, b) + cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T + cmb = spd_alt.projx(cmb) - # euclid_sqrt - sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b) - - # chol - chol_damb = dist(a, cholmb) + dist(cholmb, b) - - # ew com - ewc_damb = dist(a, cmb) + dist(cmb, b) - - # ew lin - if 'linear_eigen' in blacklist: - ewl_damb = 0 - else: - ewl_damb = dist(a, lmb) + dist(lmb, b) - - # ew sqrt - if 'sqrt_eigen' in blacklist: - ewlsqrt_damb = 0 - else: - ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b) - - # ew sca sqrt - ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b) + lmb = elvmb @ th.diag(ewmb) @ elvmb.T + #lmb = spd_alt.projx(lmb) + + # THIS + lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T + + linvmb = elvmb @ th.diag(ewinvmb) @ elvmb.T + + lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T + #lrmb = spd_alt.projx(lrmb) + + lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T + + smb = esvmb @ th.diag(ewmb) @ esvmb.T + #smb = spd_alt.projx(smb) + + lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T + + # checking + if True: + a = a.numpy() # Sigma_old + b = b.numpy() # Sigma + + emb = emb.numpy() # euclidean line + cholmb = cholmb.numpy() # line in chol space + sqrtmb = sqrtmb.numpy() # line in spq-matrix space + riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case) + cmb = cmb.numpy() # eigen under commutative assumption + lmb = lmb.numpy() # eigen with linear basis interpolation + lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic + # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol) + lsqrtmb = lsqrtmb.numpy() + # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW + lrsqrtmb = lrsqrtmb.numpy() + # eigen with scaled eigenbasis interpolation and sqrt interpol for EW + lssqrtmb = lssqrtmb.numpy() + smb = smb.numpy() # eigen with scaled interpolation + precmb = precmb.numpy() + linvmb = linvmb.numpy() + + # ground truth + tru_damb = dist(a, b) + + # euclid + euc_damb = dist(a, emb) + dist(emb, b) + + # euclid + riemann_damb = dist(a, riemannmb) + dist(riemannmb, b) + + # euclid_sqrt + sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b) + + # prec + prec_damb = dist(a, precmb) + dist(precmb, b) + + # chol + chol_damb = dist(a, cholmb) + dist(cholmb, b) + + # ew com + ewc_damb = dist(a, cmb) + dist(cmb, b) + + # ew lin + if 'linear_eigen' in blacklist: + ewl_damb = 0 + else: + ewl_damb = dist(a, lmb) + dist(lmb, b) + + # ew inv + ewinv_damb = dist(a, linvmb) + dist(linvmb, b) + + # ew sqrt + if 'sqrt_eigen' in blacklist: + ewlsqrt_damb = 0 + else: + ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b) + + # ew sca sqrt + ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b) + + # ew lin riemann + ewlr_damb = dist(a, lrmb) + dist(lrmb, b) + + # ew riemann sqrt + ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b) + + # ew sca + ews_damb = dist(a, smb) + dist(smb, b) + + akku += tru_damb + + return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb), abs(prec_damb-tru_damb), abs(ewinv_damb-tru_damb) + # except: + # print('num issue') + # return 0, 0, 0, 0 - # ew lin riemann - ewlr_damb = dist(a, lrmb) + dist(lrmb, b) - # ew riemann sqrt - ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b) - - # ew sca - ews_damb = dist(a, smb) + dist(smb, b) - - akku += dist(sqrtmb, lsqrtmb)/tru_damb - - return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb) - #except: - # print('num issue') - # return 0, 0, 0, 0 - def testSingle(local=True): - a, b = genRandSPDs(local=local) - return calcErrors(a, b) + a, b = genRandSPDs(local=local) + return calcErrors(a, b) + def test(num=1024, local=True): - euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - for i in tqdm(range(num)): - res = False - while res == False: - res = testSingle(local=local) - euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err = res - euc_errs += euc_err - ewc_errs += ewc_err - ewl_errs += ewl_err - ews_errs += ews_err - chol_errs += chol_err - sqrt_errs += sqrt_err - ewlr_errs += ewlr_err - ewlrsqrt_errs += ewlrsqrt_err - ewlsqrt_errs += ewlsqrt_err - rie_errs += rie_err - ewlssqrt_errs += ewlssqrt_err - return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num - -names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen'] - -res = th.Tensor(test(num=num, local=True))/d*100 - -for n,r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False): - if not n in blacklist: - print(n+': '+'%.6f' % r+'%') - -print('---') -print(str(akku/num*100) + '%') + euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs, prec_errs, ewinv_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + for i in tqdm(range(num)): + res = False + while res == False: + res = testSingle(local=local) + euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err, prec_err, ewinv_err = res + euc_errs += euc_err + ewc_errs += ewc_err + ewl_errs += ewl_err + ews_errs += ews_err + chol_errs += chol_err + sqrt_errs += sqrt_err + ewlr_errs += ewlr_err + ewlrsqrt_errs += ewlrsqrt_err + ewlsqrt_errs += ewlsqrt_err + rie_errs += rie_err + ewlssqrt_errs += ewlssqrt_err + prec_errs += prec_err + ewinv_errs += ewinv_err + return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num, prec_errs/num, ewinv_errs/num -#--- +names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', + 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen', 'euclidean_prec', 'inv_eigen'] + + +def main(): + res = th.Tensor(test(num=num, local=True))/(akku/num)*100 + + for n, r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False): + if not n in blacklist: + print(n+': '+'%.6f' % r+'%') + + +if __name__ == '__main__': + main() + +# --- #s = 3.14159 #d = 0.01 #eps = 0.001 -#100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s] -#sqrt_eigen: 0.030458% -#scaled_sqrt_eigen: 0.030458% -#euclidean_chol: 0.033712% -#euclidean: 0.119651% -#scaled_eigen: 0.119896% -#linear_eigen: 0.119899% -#commutative_eigen: 0.154270% -#--- -#0.012237632813561243% +# 100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s] +# sqrt_eigen: 0.030458% +# scaled_sqrt_eigen: 0.030458% +# euclidean_chol: 0.033712% +# euclidean: 0.119651% +# scaled_eigen: 0.119896% +# linear_eigen: 0.119899% +# commutative_eigen: 0.154270% +# --- +# 0.012237632813561243%