Tidying
This commit is contained in:
		
							parent
							
								
									0702213e84
								
							
						
					
					
						commit
						6c7fc37116
					
				
							
								
								
									
										134
									
								
								metastable_projections/misc/distTools.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								metastable_projections/misc/distTools.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,134 @@ | |||||||
|  | import torch as th | ||||||
|  | 
 | ||||||
|  | from stable_baselines3.common.distributions import Distribution as SB3_Distribution | ||||||
|  | 
 | ||||||
|  | from ..distributions import UniversalGaussianDistribution, AnyDistribution | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def get_mean_and_chol(p: AnyDistribution, expand=False):
    """Return the mean and the scale factor of a supported distribution.

    Args:
        p: A torch ``Normal``, ``Independent`` or ``MultivariateNormal``,
            or an SB3 distribution wrapping one of those in ``.distribution``.
        expand: Only used for ``Normal``/``Independent``. When true, the
            stddev is turned into a diagonal matrix via ``th.diag_embed``;
            otherwise the stddev is returned unchanged.

    Returns:
        A ``(mean, scale)`` tuple. ``scale`` is ``p.stddev`` (optionally
        diag-embedded) for ``Normal``/``Independent`` and ``p.scale_tril``
        for ``MultivariateNormal``.

    Raises:
        Exception: If the distribution type is not handled.
    """
    diagonal_kinds = (th.distributions.Normal, th.distributions.Independent)

    if isinstance(p, diagonal_kinds):
        mean = p.mean
        std = p.stddev
        return mean, (th.diag_embed(std) if expand else std)

    if isinstance(p, th.distributions.MultivariateNormal):
        return p.mean, p.scale_tril

    if isinstance(p, SB3_Distribution):
        # SB3 wrappers hold the torch distribution in `.distribution`.
        return get_mean_and_chol(p.distribution, expand=expand)

    raise Exception('Dist-Type not implemented')
|  | 
 | ||||||
|  | 
 | ||||||
def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False):
    """Return the mean and the stored covariance square root of ``p``.

    The square root is read from the ``cov_sqrt`` attribute; it is never
    computed here.

    Args:
        p: Distribution carrying a ``cov_sqrt`` attribute.
        expand: When true and ``cov_sqrt`` has at most two dimensions, it is
            turned into a diagonal matrix via ``th.diag_embed``.

    Returns:
        A ``(mean, sqrt_cov)`` tuple. If the leading dimension of
        ``cov_sqrt`` differs from that of the mean, ``cov_sqrt`` is expanded
        along that dimension to match.

    Raises:
        Exception: If ``p`` has no ``cov_sqrt`` attribute.
    """
    if not hasattr(p, 'cov_sqrt'):
        raise Exception(
            'Distribution was not induced from sqrt. On-demand calculation is not supported.')

    # Only the mean is needed; the Cholesky factor is discarded.
    mean, _ = get_mean_and_chol(p, expand=False)
    sqrt_cov = p.cov_sqrt

    leading = mean.shape[0]
    if leading != sqrt_cov.shape[0]:
        # Match the leading dimension of the mean, keep the remaining dims.
        sqrt_cov = sqrt_cov.expand((leading, *sqrt_cov.shape[1:]))

    if expand and sqrt_cov.dim() <= 2:
        sqrt_cov = th.diag_embed(sqrt_cov)

    return mean, sqrt_cov
|  | 
 | ||||||
|  | 
 | ||||||
def get_cov(p: AnyDistribution):
    """Return the covariance matrix of a supported distribution.

    ``Normal``/``Independent`` yield ``th.diag_embed(p.variance)``;
    ``MultivariateNormal`` yields ``p.covariance_matrix``; SB3 distributions
    are unwrapped through ``.distribution``.

    Raises:
        Exception: If the distribution type is not handled.
    """
    if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
        variance = p.variance
        return th.diag_embed(variance)

    if isinstance(p, th.distributions.MultivariateNormal):
        return p.covariance_matrix

    if isinstance(p, SB3_Distribution):
        return get_cov(p.distribution)

    raise Exception('Dist-Type not implemented')
|  | 
 | ||||||
|  | 
 | ||||||
def has_diag_cov(p: AnyDistribution, numerical_check=False):
    """Tell whether ``p`` has a diagonal covariance.

    Args:
        p: Distribution to inspect; SB3 distributions are unwrapped through
            ``.distribution``.
        numerical_check: When false, only ``Normal``/``Independent`` count as
            diagonal. When true, any other distribution has its covariance
            matrix inspected for exactly-zero off-diagonal entries.

    Returns:
        ``True`` if the covariance is diagonal under the rules above.
    """
    if isinstance(p, SB3_Distribution):
        return has_diag_cov(p.distribution, numerical_check=numerical_check)

    if isinstance(p, (th.distributions.Normal, th.distributions.Independent)):
        return True

    if numerical_check:
        # Remove the diagonal and require everything left to be exactly zero.
        cov = get_cov(p)
        diag_only = th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1))
        off_diag = cov - diag_only
        return th.equal(off_diag, th.zeros_like(cov))

    return False
|  | 
 | ||||||
|  | 
 | ||||||
def is_contextual(p: AnyDistribution):
    """Report whether ``p`` is contextual.

    Placeholder: the argument is not inspected and the answer is always
    ``False``.
    """
    # TODO: Implement for UniversalGaussianDistribution
    return False
|  | 
 | ||||||
|  | 
 | ||||||
def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=False):
    """Return the diagonal of the covariance matrix of ``p``.

    Args:
        p: Distribution whose covariance is read via ``get_cov``.
        check_diag: When true, first verify via ``has_diag_cov`` that the
            covariance is diagonal.
        numerical_check: Forwarded to ``has_diag_cov``.

    Raises:
        Exception: If ``check_diag`` is set and the covariance is not
            considered diagonal.
    """
    if check_diag:
        is_diag = has_diag_cov(p, numerical_check=numerical_check)
        if not is_diag:
            raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')

    cov = get_cov(p)
    return th.diagonal(cov, dim1=-2, dim2=-1)
|  | 
 | ||||||
|  | 
 | ||||||
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
    """Build a new distribution of the same kind as ``orig_p``.

    Args:
        orig_p: Template distribution. Handled kinds:
            ``UniversalGaussianDistribution`` (delegates to its
            ``new_dist_like_me``), torch ``Normal``, ``Independent`` and
            ``MultivariateNormal``, and SB3 distributions wrapping one of the
            three torch kinds in ``.distribution``.
        mean: Mean of the new distribution.
        chol: Scale of the new distribution. For a bare ``Normal`` or
            ``Independent`` template, a ``chol`` whose shape differs from
            ``orig_p.stddev`` is reduced to the diagonal of its last two
            dimensions.

    Returns:
        A new distribution mirroring the type of ``orig_p``. For SB3
        templates, a fresh ``orig_p.__class__(orig_p.action_dim)`` whose
        ``distribution`` attribute holds the new torch distribution.

    Raises:
        Exception: If ``orig_p`` (or the torch distribution wrapped by an
            SB3 template) is of an unhandled type.
    """
    if isinstance(orig_p, UniversalGaussianDistribution):
        return orig_p.new_dist_like_me(mean, chol)
    elif isinstance(orig_p, th.distributions.Normal):
        if orig_p.stddev.shape != chol.shape:
            # Take the diagonal over the trailing matrix dims (as the rest of
            # this module does) instead of hard-coding dims 1 and 2, so any
            # number of leading batch dimensions is supported.
            chol = th.diagonal(chol, dim1=-2, dim2=-1)
        return th.distributions.Normal(mean, chol)
    elif isinstance(orig_p, th.distributions.Independent):
        if orig_p.stddev.shape != chol.shape:
            chol = th.diagonal(chol, dim1=-2, dim2=-1)
        return th.distributions.Independent(th.distributions.Normal(mean, chol), 1)
    elif isinstance(orig_p, th.distributions.MultivariateNormal):
        return th.distributions.MultivariateNormal(mean, scale_tril=chol)
    elif isinstance(orig_p, SB3_Distribution):
        p = orig_p.distribution
        # NOTE(review): unlike the bare Normal/Independent branches above, no
        # matrix-to-diagonal reduction of `chol` is applied here - confirm
        # whether callers always pass an already matching shape.
        if isinstance(p, th.distributions.Normal):
            inner = th.distributions.Normal(mean, chol)
        elif isinstance(p, th.distributions.Independent):
            inner = th.distributions.Independent(
                th.distributions.Normal(mean, chol), 1)
        elif isinstance(p, th.distributions.MultivariateNormal):
            inner = th.distributions.MultivariateNormal(
                mean, scale_tril=chol)
        else:
            raise Exception('Dist-Type not implemented (of sb3 dist)')
        # Re-create the SB3 wrapper and attach the new torch distribution.
        p_out = orig_p.__class__(orig_p.action_dim)
        p_out.distribution = inner
        return p_out
    else:
        raise Exception('Dist-Type not implemented')
|  | 
 | ||||||
|  | 
 | ||||||
def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: th.Tensor):
    """Build a distribution like ``orig_p`` from a covariance square root.

    The square root is converted to a Cholesky factor with ``_sqrt_to_chol``
    and handed to ``new_dist_like``. The original ``cov_sqrt`` is then stored
    as an attribute on the result and, if the result has a ``distribution``
    attribute, on that object as well.

    Returns:
        The newly created distribution.
    """
    result = new_dist_like(orig_p, mean, _sqrt_to_chol(cov_sqrt))

    # Keep the square root around so it can be read back later without
    # recomputation.
    result.cov_sqrt = cov_sqrt
    if hasattr(result, 'distribution'):
        setattr(result.distribution, 'cov_sqrt', cov_sqrt)

    return result
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _sqrt_to_chol(cov_sqrt): | ||||||
|  |     vec = False | ||||||
|  |     if len(cov_sqrt.shape) == 2: | ||||||
|  |         vec = True | ||||||
|  | 
 | ||||||
|  |     if vec: | ||||||
|  |         cov_sqrt = th.diag_embed(cov_sqrt) | ||||||
|  | 
 | ||||||
|  |     cov = th.bmm(cov_sqrt.mT, cov_sqrt) | ||||||
|  |     cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) | ||||||
|  | 
 | ||||||
|  |     chol = th.linalg.cholesky(cov) | ||||||
|  | 
 | ||||||
|  |     if vec: | ||||||
|  |         chol = th.diagonal(chol, dim1=-2, dim2=-1) | ||||||
|  | 
 | ||||||
|  |     return chol | ||||||
							
								
								
									
										31
									
								
								metastable_projections/misc/norm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								metastable_projections/misc/norm.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,31 @@ | |||||||
|  | import torch as th | ||||||
|  | from torch.distributions.multivariate_normal import _batch_mahalanobis | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def mahalanobis_alt(u, v, std):
    """Squared Mahalanobis-style distance via a triangular solve.

    Adapted from Fabian's code (public version).

    Solves ``std @ x = (u - v)`` treating ``std`` as lower triangular and
    returns the sum of ``x ** 2`` over the last two dimensions.

    Args:
        u: First point(s); ``u - v`` must be a valid right-hand side for the
            triangular solve (at least 2-D).
        v: Second point(s), broadcastable against ``u``.
        std: Lower-triangular factor(s).

    Returns:
        Tensor with the last two dimensions reduced away.
    """
    delta = u - v
    # th.triangular_solve is deprecated; th.linalg.solve_triangular is its
    # replacement and takes the coefficient matrix first.
    solved = th.linalg.solve_triangular(std, delta, upper=False)
    return solved.pow(2).sum([-2, -1])
|  | 
 | ||||||
|  | 
 | ||||||
def mahalanobis(u, v, chol):
    """Return torch's ``_batch_mahalanobis`` of ``u - v`` under ``chol``.

    Args:
        u: First point(s).
        v: Second point(s), broadcastable against ``u``.
        chol: Cholesky factor(s) passed as the first argument to
            ``_batch_mahalanobis``.
    """
    return _batch_mahalanobis(chol, u - v)
|  | 
 | ||||||
|  | 
 | ||||||
def frob_sq(diff, is_spd=False):
    """Squared Frobenius norm of ``diff``, reduced over all dims but the first.

    Args:
        diff: Input tensor; dimension 0 is kept, all others are reduced.
        is_spd: When true, delegate to ``_frob_sq_spd`` instead of using
            ``th.norm``.

    Returns:
        The squared norm per entry along dimension 0.
    """
    if is_spd:
        # Alternative path intended for SPD inputs.
        return _frob_sq_spd(diff)

    reduce_dims = tuple(range(1, diff.dim()))
    fro = th.norm(diff, p='fro', dim=reduce_dims)
    return fro.pow(2)
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _frob_sq_spd(diff): | ||||||
|  |     return _batch_trace(diff @ diff) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _batch_trace(x): | ||||||
|  |     return th.diagonal(x, dim1=-2, dim2=-1).sum(-1) | ||||||
							
								
								
									
										4
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								setup.py
									
									
									
									
									
								
							| @ -1,12 +1,12 @@ | |||||||
| from setuptools import setup, find_packages | from setuptools import setup, find_packages | ||||||
| 
 | 
 | ||||||
# Package metadata (post-commit state of the diff).
setup(
    name='metastable-projections',
    version='1.0.0',
    # url='https://github.com/mypackage.git',
    # author='Author Name',
    # author_email='author@gmail.com',
    # description='Description of my package',
    # Discover the packages in this repository instead of listing the
    # literal '.' entry, which is not a dotted package name.
    packages=find_packages(),
    install_requires=['torch', 'stable_baselines3'],
)
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user