initial commit
This commit is contained in:
commit
6f43bfb719
238
spdesic.py
Normal file
238
spdesic.py
Normal file
@ -0,0 +1,238 @@
|
||||
import torch as th
|
||||
import geoopt
|
||||
import pymanopt
|
||||
from tqdm import tqdm
|
||||
|
||||
n = 3
|
||||
|
||||
spd_alt = geoopt.SymmetricPositiveDefinite()
|
||||
spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1)
|
||||
so = pymanopt.manifolds.special_orthogonal_group.SpecialOrthogonalGroup(n, k=1)
|
||||
|
||||
dist = spd.dist
|
||||
|
||||
shape = (n,n)
|
||||
eta = 1
|
||||
pos = (eta)/(1+eta)
|
||||
s = 3.14159
|
||||
d = 0.01
|
||||
eps = 0.001
|
||||
|
||||
num = 1024
|
||||
|
||||
|
||||
blacklist = []
|
||||
blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann']
|
||||
#blacklist += ['linear_eigen', 'sqrt_eigen']
|
||||
#blacklist += ['scaled_eigen']
|
||||
#blacklist += ['euclidean_sqrt']
|
||||
|
||||
def genRandSPDs(local=True):
|
||||
if local:
|
||||
a = spd.random_point()*s
|
||||
#a = spd.random(shape)
|
||||
#return a, a + spd.random(shape)*d
|
||||
return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(3)*eta
|
||||
return spd.random(shape), spd.random(shape)
|
||||
|
||||
akku = 0
|
||||
|
||||
def calcErrors(a, b):
|
||||
global akku
|
||||
### eigen decomp
|
||||
ewa, eva = th.linalg.eigh(a)
|
||||
ewb, evb = th.linalg.eigh(b)
|
||||
|
||||
ewa, eva = ewa.real, eva.real
|
||||
ewb, evb = ewb.real, evb.real
|
||||
|
||||
if th.norm(eva-evb) > d*2:
|
||||
# EVs flipped; try again...
|
||||
return False
|
||||
|
||||
### euclidean approx (also depends on eigendecomp for fair comparison)
|
||||
ar = eva @ th.diag(ewa) @ eva.T
|
||||
br = evb @ th.diag(ewb) @ evb.T
|
||||
|
||||
emb = (ar + eta*br)/(1+eta)
|
||||
|
||||
### euclidean_sqrt
|
||||
asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
|
||||
bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
|
||||
|
||||
mbsqrt = (asqrt + eta*bsqrt)/(1+eta)
|
||||
sqrtmb = mbsqrt@mbsqrt
|
||||
|
||||
### chol approx (also depends on eigendecomp for fair comparison)
|
||||
la = th.linalg.cholesky(ar)
|
||||
lb = th.linalg.cholesky(br)
|
||||
|
||||
cholmb_chol = (la + eta*lb)/(1+eta)
|
||||
cholmb = cholmb_chol@cholmb_chol.T
|
||||
|
||||
### riemann
|
||||
if 'riemann' in blacklist:
|
||||
riemannmb = ar
|
||||
else:
|
||||
riemannmb = th.Tensor(spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float()
|
||||
|
||||
### eigen approx
|
||||
|
||||
# com
|
||||
ewmb, ecvmb = (ewa + eta*ewb)/(1+eta), eva
|
||||
|
||||
# lin
|
||||
elvmb = (eva + eta*evb)/(1+eta)
|
||||
# not closed form, but stable gradients
|
||||
elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva))
|
||||
if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0:
|
||||
elvmb = elvmb_retr
|
||||
|
||||
# lin_riemann
|
||||
if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist:
|
||||
elrvmb = eva
|
||||
else:
|
||||
elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float()
|
||||
|
||||
# ew_sqrt
|
||||
ewsqrtmb_sqrt = (th.sqrt(ewa) + eta*th.sqrt(ewb))/(1+eta)
|
||||
ewsqrtmb = ewsqrtmb_sqrt**2
|
||||
|
||||
# ew scaling
|
||||
esvmb = ((eva@th.diag(ewa) + eta*evb@th.diag(ewb))/(1+eta))/ewmb
|
||||
|
||||
cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T
|
||||
cmb = spd_alt.projx(cmb)
|
||||
|
||||
lmb = elvmb @ th.diag(ewmb) @ elvmb.T
|
||||
#lmb = spd_alt.projx(lmb)
|
||||
|
||||
# THIS
|
||||
lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T
|
||||
|
||||
lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T
|
||||
#lrmb = spd_alt.projx(lrmb)
|
||||
|
||||
lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T
|
||||
|
||||
smb = esvmb @ th.diag(ewmb) @ esvmb.T
|
||||
#smb = spd_alt.projx(smb)
|
||||
|
||||
lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T
|
||||
|
||||
### checking
|
||||
if True:
|
||||
a = a.numpy() # Sigma_old
|
||||
b = b.numpy() # Sigma
|
||||
|
||||
emb = emb.numpy() # euclidean line
|
||||
cholmb = cholmb.numpy() # line in chol space
|
||||
sqrtmb = sqrtmb.numpy() # line in spq-matrix space
|
||||
riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case)
|
||||
cmb = cmb.numpy() # eigen under commutative assumption
|
||||
lmb = lmb.numpy() # eigen with linear basis interpolation
|
||||
lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic
|
||||
lsqrtmb = lsqrtmb.numpy() # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
|
||||
lrsqrtmb = lrsqrtmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
|
||||
lssqrtmb = lssqrtmb.numpy() # eigen with scaled eigenbasis interpolation and sqrt interpol for EW
|
||||
smb = smb.numpy() # eigen with scaled interpolation
|
||||
|
||||
# ground truth
|
||||
tru_damb = dist(a, b)
|
||||
|
||||
# euclid
|
||||
euc_damb = dist(a, emb) + dist(emb, b)
|
||||
|
||||
# euclid
|
||||
riemann_damb = dist(a, riemannmb) + dist(riemannmb, b)
|
||||
|
||||
# euclid_sqrt
|
||||
sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b)
|
||||
|
||||
# chol
|
||||
chol_damb = dist(a, cholmb) + dist(cholmb, b)
|
||||
|
||||
# ew com
|
||||
ewc_damb = dist(a, cmb) + dist(cmb, b)
|
||||
|
||||
# ew lin
|
||||
if 'linear_eigen' in blacklist:
|
||||
ewl_damb = 0
|
||||
else:
|
||||
ewl_damb = dist(a, lmb) + dist(lmb, b)
|
||||
|
||||
# ew sqrt
|
||||
if 'sqrt_eigen' in blacklist:
|
||||
ewlsqrt_damb = 0
|
||||
else:
|
||||
ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b)
|
||||
|
||||
# ew sca sqrt
|
||||
ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b)
|
||||
|
||||
# ew lin riemann
|
||||
ewlr_damb = dist(a, lrmb) + dist(lrmb, b)
|
||||
|
||||
# ew riemann sqrt
|
||||
ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b)
|
||||
|
||||
# ew sca
|
||||
ews_damb = dist(a, smb) + dist(smb, b)
|
||||
|
||||
akku += dist(sqrtmb, lsqrtmb)/tru_damb
|
||||
|
||||
return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb)
|
||||
#except:
|
||||
# print('num issue')
|
||||
# return 0, 0, 0, 0
|
||||
|
||||
def testSingle(local=True):
|
||||
a, b = genRandSPDs(local=local)
|
||||
return calcErrors(a, b)
|
||||
|
||||
def test(num=1024, local=True):
|
||||
euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
|
||||
for i in tqdm(range(num)):
|
||||
res = False
|
||||
while res == False:
|
||||
res = testSingle(local=local)
|
||||
euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err = res
|
||||
euc_errs += euc_err
|
||||
ewc_errs += ewc_err
|
||||
ewl_errs += ewl_err
|
||||
ews_errs += ews_err
|
||||
chol_errs += chol_err
|
||||
sqrt_errs += sqrt_err
|
||||
ewlr_errs += ewlr_err
|
||||
ewlrsqrt_errs += ewlrsqrt_err
|
||||
ewlsqrt_errs += ewlsqrt_err
|
||||
rie_errs += rie_err
|
||||
ewlssqrt_errs += ewlssqrt_err
|
||||
return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num
|
||||
|
||||
names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen']
|
||||
|
||||
res = th.Tensor(test(num=num, local=True))/d*100
|
||||
|
||||
for n,r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False):
|
||||
if not n in blacklist:
|
||||
print(n+': '+'%.6f' % r+'%')
|
||||
|
||||
print('---')
|
||||
print(str(akku/num*100) + '%')
|
||||
|
||||
|
||||
#---
|
||||
#s = 3.14159
|
||||
#d = 0.01
|
||||
#eps = 0.001
|
||||
#100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s]
|
||||
#sqrt_eigen: 0.030458%
|
||||
#scaled_sqrt_eigen: 0.030458%
|
||||
#euclidean_chol: 0.033712%
|
||||
#euclidean: 0.119651%
|
||||
#scaled_eigen: 0.119896%
|
||||
#linear_eigen: 0.119899%
|
||||
#commutative_eigen: 0.154270%
|
||||
#---
|
||||
#0.012237632813561243%
|
Loading…
Reference in New Issue
Block a user