commit message

This commit is contained in:
Dominik Moritz Roth 2022-12-02 15:02:20 +01:00
parent 23236f18a4
commit 73a0271452

View File

@ -2,8 +2,9 @@ import torch as th
import geoopt import geoopt
import pymanopt import pymanopt
from tqdm import tqdm from tqdm import tqdm
import math
n = 3 n = 4
spd_alt = geoopt.SymmetricPositiveDefinite() spd_alt = geoopt.SymmetricPositiveDefinite()
spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1) spd = pymanopt.manifolds.positive_definite.SymmetricPositiveDefinite(n, k=1)
@ -11,228 +12,271 @@ so = pymanopt.manifolds.special_orthogonal_group.SpecialOrthogonalGroup(n, k=1)
dist = spd.dist dist = spd.dist
shape = (n,n) shape = (n, n)
eta = 1 eta = 1
pos = (eta)/(1+eta) pos = (eta)/(1+eta)
s = 3.14159 s = 3.14159
d = 0.01 d = 0.1
eps = 0.001 eps = 0.001
num = 1024 num = 1024
# ignore these...
fev = 1
few = 1
blacklist = [] blacklist = []
blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann'] blacklist = ['linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'riemann']
#blacklist += ['linear_eigen', 'sqrt_eigen'] #blacklist += ['linear_eigen', 'sqrt_eigen']
#blacklist += ['scaled_eigen'] blacklist += ['scaled_eigen']
#blacklist += ['euclidean_sqrt'] #blacklist += ['euclidean_prec']
#blacklist += ['euclidean_sqrt', 'commutative_eigen']
def genRandSPDs(local=True): def genRandSPDs(local=True):
if local: if local:
a = spd.random_point()*s a = spd.random_point()*s
#a = spd.random(shape) #a = spd.random(shape)
#return a, a + spd.random(shape)*d # return a, a + spd.random(shape)*d
return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(3)*eta return th.Tensor(a) + th.eye(n)*eps, th.Tensor(a + spd.random_point()*s*d) + th.eye(n)*eta
return spd.random(shape), spd.random(shape) return spd.random(shape), spd.random(shape)
akku = 0 akku = 0
def calcErrors(a, b): def calcErrors(a, b):
global akku global akku
### eigen decomp etaV, etaW = eta * fev, eta * few
ewa, eva = th.linalg.eigh(a)
ewb, evb = th.linalg.eigh(b)
ewa, eva = ewa.real, eva.real # eigen decomp
ewb, evb = ewb.real, evb.real ewa, eva = th.linalg.eigh(a)
ewb, evb = th.linalg.eigh(b)
if th.norm(eva-evb) > d*2: ewa, eva = ewa.real, eva.real
# EVs flipped; try again... ewb, evb = ewb.real, evb.real
return False
### euclidean approx (also depends on eigendecomp for fair comparison) if th.norm(eva-evb) > d*2:
ar = eva @ th.diag(ewa) @ eva.T # EVs flipped; try again...
br = evb @ th.diag(ewb) @ evb.T return False
emb = (ar + eta*br)/(1+eta) # euclidean approx (also depends on eigendecomp for fair comparison)
ar = eva @ th.diag(ewa) @ eva.T
br = evb @ th.diag(ewb) @ evb.T
### euclidean_sqrt emb = (ar + eta*br)/(1+eta)
asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
mbsqrt = (asqrt + eta*bsqrt)/(1+eta) # euclidean_sqrt
sqrtmb = mbsqrt@mbsqrt asqrt = eva @ th.diag(th.sqrt(ewa)) @ eva.T
bsqrt = evb @ th.diag(th.sqrt(ewb)) @ evb.T
### chol approx (also depends on eigendecomp for fair comparison) mbsqrt = (asqrt + eta*bsqrt)/(1+eta)
la = th.linalg.cholesky(ar) sqrtmb = mbsqrt@mbsqrt
lb = th.linalg.cholesky(br)
cholmb_chol = (la + eta*lb)/(1+eta) # euclidean_prec
cholmb = cholmb_chol@cholmb_chol.T ainv = eva @ th.diag(1/ewa) @ eva.T
binv = evb @ th.diag(1/ewb) @ evb.T
### riemann if 'euclidean_prec' in blacklist:
if 'riemann' in blacklist: precmb = ar
riemannmb = ar else:
else: precmb = th.inverse((ainv + eta*binv)/(1+eta))
riemannmb = th.Tensor(spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float()
### eigen approx # chol approx (also depends on eigendecomp for fair comparison)
la = th.linalg.cholesky(ar)
lb = th.linalg.cholesky(br)
# com cholmb_chol = (la + eta*lb)/(1+eta)
ewmb, ecvmb = (ewa + eta*ewb)/(1+eta), eva cholmb = cholmb_chol@cholmb_chol.T
# lin # riemann
elvmb = (eva + eta*evb)/(1+eta) if 'riemann' in blacklist:
# not closed form, but stable gradients riemannmb = ar
elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva)) else:
if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0: riemannmb = th.Tensor(
elvmb = elvmb_retr spd.exp(ar.numpy(), pos*(spd.log(ar.numpy(), br.numpy())))).real.float()
# lin_riemann # eigen approx
if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist:
elrvmb = eva
else:
elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float()
# ew_sqrt # com
ewsqrtmb_sqrt = (th.sqrt(ewa) + eta*th.sqrt(ewb))/(1+eta) ewmb, ecvmb = (ewa + etaW*ewb)/(1+etaW), eva
ewsqrtmb = ewsqrtmb_sqrt**2
# ew scaling # lin
esvmb = ((eva@th.diag(ewa) + eta*evb@th.diag(ewb))/(1+eta))/ewmb elvmb = (eva + etaV*evb)/(1+etaV)
# not closed form, but stable gradients
elvmb_retr = th.Tensor(so.retraction(eva, elvmb-eva))
if so.norm(elvmb_retr, elvmb-elvmb_retr) > 1.0:
elvmb = elvmb_retr
cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T # lin_riemann
cmb = spd_alt.projx(cmb) if 'linear_riemann_eigen' in blacklist and 'linear_riemann_eigen_sqrt' in blacklist:
elrvmb = eva
else:
elrvmb = th.Tensor(so.exp(eva, pos*(so.log(eva, evb)))).real.float()
lmb = elvmb @ th.diag(ewmb) @ elvmb.T # ew_sqrt
#lmb = spd_alt.projx(lmb) ewsqrtmb_sqrt = (th.sqrt(ewa) + etaW*th.sqrt(ewb))/(1+etaW)
ewsqrtmb = ewsqrtmb_sqrt**2
# THIS # ew_inv
lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T ewinvmb = 1/((1/ewa + etaW*(1/ewb))/(1+etaW))
lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T # ew scaling
#lrmb = spd_alt.projx(lrmb) esvmb = ((eva@th.diag(ewa) + etaV*evb@th.diag(ewb))/(1+etaV))/ewmb
lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T cmb = ecvmb @ th.diag(ewmb) @ ecvmb.T
cmb = spd_alt.projx(cmb)
smb = esvmb @ th.diag(ewmb) @ esvmb.T lmb = elvmb @ th.diag(ewmb) @ elvmb.T
#smb = spd_alt.projx(smb) #lmb = spd_alt.projx(lmb)
lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T # THIS
lsqrtmb = elvmb @ th.diag(ewsqrtmb) @ elvmb.T
### checking linvmb = elvmb @ th.diag(ewinvmb) @ elvmb.T
if True:
a = a.numpy() # Sigma_old
b = b.numpy() # Sigma
emb = emb.numpy() # euclidean line lrmb = elrvmb @ th.diag(ewmb) @ elrvmb.T
cholmb = cholmb.numpy() # line in chol space #lrmb = spd_alt.projx(lrmb)
sqrtmb = sqrtmb.numpy() # line in spq-matrix space
riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case)
cmb = cmb.numpy() # eigen under commutative assumption
lmb = lmb.numpy() # eigen with linear basis interpolation
lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic
lsqrtmb = lsqrtmb.numpy() # eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
lrsqrtmb = lrsqrtmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
lssqrtmb = lssqrtmb.numpy() # eigen with scaled eigenbasis interpolation and sqrt interpol for EW
smb = smb.numpy() # eigen with scaled interpolation
# ground truth lrsqrtmb = elrvmb @ th.diag(ewsqrtmb) @ elrvmb.T
tru_damb = dist(a, b)
# euclid smb = esvmb @ th.diag(ewmb) @ esvmb.T
euc_damb = dist(a, emb) + dist(emb, b) #smb = spd_alt.projx(smb)
# euclid lssqrtmb = esvmb @ th.diag(ewsqrtmb) @ esvmb.T
riemann_damb = dist(a, riemannmb) + dist(riemannmb, b)
# euclid_sqrt # checking
sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b) if True:
a = a.numpy() # Sigma_old
b = b.numpy() # Sigma
# chol emb = emb.numpy() # euclidean line
chol_damb = dist(a, cholmb) + dist(cholmb, b) cholmb = cholmb.numpy() # line in chol space
sqrtmb = sqrtmb.numpy() # line in spq-matrix space
riemannmb = riemannmb.numpy() # spd geodesic (theoretical best case)
cmb = cmb.numpy() # eigen under commutative assumption
lmb = lmb.numpy() # eigen with linear basis interpolation
lrmb = lrmb.numpy() # eigen with eigenbasis interpolation along so(n) geodesic
# eigen with linear eigenbasis interpolation and sqrt interpol for EW (=std interpol)
lsqrtmb = lsqrtmb.numpy()
# eigen with eigenbasis interpolation along so(n) geodesic and sqrt interpol for EW
lrsqrtmb = lrsqrtmb.numpy()
# eigen with scaled eigenbasis interpolation and sqrt interpol for EW
lssqrtmb = lssqrtmb.numpy()
smb = smb.numpy() # eigen with scaled interpolation
precmb = precmb.numpy()
linvmb = linvmb.numpy()
# ew com # ground truth
ewc_damb = dist(a, cmb) + dist(cmb, b) tru_damb = dist(a, b)
# ew lin # euclid
if 'linear_eigen' in blacklist: euc_damb = dist(a, emb) + dist(emb, b)
ewl_damb = 0
else:
ewl_damb = dist(a, lmb) + dist(lmb, b)
# ew sqrt # euclid
if 'sqrt_eigen' in blacklist: riemann_damb = dist(a, riemannmb) + dist(riemannmb, b)
ewlsqrt_damb = 0
else:
ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b)
# ew sca sqrt # euclid_sqrt
ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b) sqrt_damb = dist(a, sqrtmb) + dist(sqrtmb, b)
# ew lin riemann # prec
ewlr_damb = dist(a, lrmb) + dist(lrmb, b) prec_damb = dist(a, precmb) + dist(precmb, b)
# ew riemann sqrt # chol
ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b) chol_damb = dist(a, cholmb) + dist(cholmb, b)
# ew sca # ew com
ews_damb = dist(a, smb) + dist(smb, b) ewc_damb = dist(a, cmb) + dist(cmb, b)
akku += dist(sqrtmb, lsqrtmb)/tru_damb # ew lin
if 'linear_eigen' in blacklist:
ewl_damb = 0
else:
ewl_damb = dist(a, lmb) + dist(lmb, b)
# ew inv
ewinv_damb = dist(a, linvmb) + dist(linvmb, b)
# ew sqrt
if 'sqrt_eigen' in blacklist:
ewlsqrt_damb = 0
else:
ewlsqrt_damb = dist(a, lsqrtmb) + dist(lsqrtmb, b)
# ew sca sqrt
ewlssqrt_damb = dist(a, lssqrtmb) + dist(lssqrtmb, b)
# ew lin riemann
ewlr_damb = dist(a, lrmb) + dist(lrmb, b)
# ew riemann sqrt
ewlrsqrt_damb = dist(a, lrsqrtmb) + dist(lrsqrtmb, b)
# ew sca
ews_damb = dist(a, smb) + dist(smb, b)
akku += tru_damb
return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb), abs(prec_damb-tru_damb), abs(ewinv_damb-tru_damb)
# except:
# print('num issue')
# return 0, 0, 0, 0
return abs(euc_damb-tru_damb), abs(ewc_damb-tru_damb), abs(ewl_damb-tru_damb), abs(ews_damb-tru_damb), abs(chol_damb-tru_damb), abs(sqrt_damb-tru_damb), abs(ewlr_damb-tru_damb), abs(ewlrsqrt_damb-tru_damb), abs(ewlsqrt_damb-tru_damb), abs(riemann_damb-tru_damb), abs(ewlssqrt_damb-tru_damb)
#except:
# print('num issue')
# return 0, 0, 0, 0
def testSingle(local=True): def testSingle(local=True):
a, b = genRandSPDs(local=local) a, b = genRandSPDs(local=local)
return calcErrors(a, b) return calcErrors(a, b)
def test(num=1024, local=True): def test(num=1024, local=True):
euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 euc_errs, ewc_errs, ewl_errs, ews_errs, chol_errs, sqrt_errs, ewlr_errs, ewlrsqrt_errs, ewlsqrt_errs, rie_errs, ewlssqrt_errs, prec_errs, ewinv_errs = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
for i in tqdm(range(num)): for i in tqdm(range(num)):
res = False res = False
while res == False: while res == False:
res = testSingle(local=local) res = testSingle(local=local)
euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err = res euc_err, ewc_err, ewl_err, ews_err, chol_err, sqrt_err, ewlr_err, ewlrsqrt_err, ewlsqrt_err, rie_err, ewlssqrt_err, prec_err, ewinv_err = res
euc_errs += euc_err euc_errs += euc_err
ewc_errs += ewc_err ewc_errs += ewc_err
ewl_errs += ewl_err ewl_errs += ewl_err
ews_errs += ews_err ews_errs += ews_err
chol_errs += chol_err chol_errs += chol_err
sqrt_errs += sqrt_err sqrt_errs += sqrt_err
ewlr_errs += ewlr_err ewlr_errs += ewlr_err
ewlrsqrt_errs += ewlrsqrt_err ewlrsqrt_errs += ewlrsqrt_err
ewlsqrt_errs += ewlsqrt_err ewlsqrt_errs += ewlsqrt_err
rie_errs += rie_err rie_errs += rie_err
ewlssqrt_errs += ewlssqrt_err ewlssqrt_errs += ewlssqrt_err
return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num prec_errs += prec_err
ewinv_errs += ewinv_err
names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt', 'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen'] return euc_errs/num, ewc_errs/num, ewl_errs/num, ews_errs/num, chol_errs/num, sqrt_errs/num, ewlr_errs/num, ewlrsqrt_errs/num, ewlsqrt_errs/num, rie_errs/num, ewlsqrt_errs/num, prec_errs/num, ewinv_errs/num
res = th.Tensor(test(num=num, local=True))/d*100
for n,r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False):
if not n in blacklist:
print(n+': '+'%.6f' % r+'%')
print('---')
print(str(akku/num*100) + '%')
#--- names = ['euclidean', 'commutative_eigen', 'linear_eigen', 'scaled_eigen', 'euclidean_chol', 'euclidean_sqrt',
'linear_riemann_eigen', 'linear_riemann_eigen_sqrt', 'sqrt_eigen', 'riemann', 'scaled_sqrt_eigen', 'euclidean_prec', 'inv_eigen']
def main():
res = th.Tensor(test(num=num, local=True))/(akku/num)*100
for n, r in sorted(zip(names, res), key=lambda x: float(x[1].item()), reverse=False):
if not n in blacklist:
print(n+': '+'%.6f' % r+'%')
if __name__ == '__main__':
main()
# ---
#s = 3.14159 #s = 3.14159
#d = 0.01 #d = 0.01
#eps = 0.001 #eps = 0.001
#100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s] # 100%|██████████████████████████████████████████████████████████████████████████████████████████| 131072/131072 [08:24<00:00, 259.59it/s]
#sqrt_eigen: 0.030458% # sqrt_eigen: 0.030458%
#scaled_sqrt_eigen: 0.030458% # scaled_sqrt_eigen: 0.030458%
#euclidean_chol: 0.033712% # euclidean_chol: 0.033712%
#euclidean: 0.119651% # euclidean: 0.119651%
#scaled_eigen: 0.119896% # scaled_eigen: 0.119896%
#linear_eigen: 0.119899% # linear_eigen: 0.119899%
#commutative_eigen: 0.154270% # commutative_eigen: 0.154270%
#--- # ---
#0.012237632813561243% # 0.012237632813561243%