Fix issues when a diagonal matrix sqrt is provided
This commit is contained in:
		
							parent
							
								
									7ad4858e8a
								
							
						
					
					
						commit
						f129928635
					
				| @ -27,20 +27,23 @@ def get_mean_and_chol(p: AnyDistribution, expand=False): | ||||
| 
 | ||||
| 
 | ||||
| def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): | ||||
|     mean, chol = get_mean_and_chol(p, expand=False) | ||||
|     if not hasattr(p, 'cov_sqrt'): | ||||
|         raise Exception( | ||||
|             'Distribution was not induced from sqrt. On-demand calculation is not supported.') | ||||
|         if len(p.distribution.scale.shape)==2: # In the factorized case, chol = matSqrt | ||||
|             sqrt_cov = p.distribution.scale | ||||
|         else: | ||||
|             raise Exception( | ||||
|                 'Distribution was not induced from sqrt. On-demand calculation is not supported.') | ||||
|     else: | ||||
|         mean, chol = get_mean_and_chol(p, expand=False) | ||||
|         sqrt_cov = p.cov_sqrt | ||||
|         if mean.shape[0] != sqrt_cov.shape[0]: | ||||
|             shape = list(sqrt_cov.shape) | ||||
|             shape[0] = mean.shape[0] | ||||
|             shape = tuple(shape) | ||||
|             sqrt_cov = sqrt_cov.expand(shape) | ||||
|         if expand and len(sqrt_cov.shape) <= 2: | ||||
|             sqrt_cov = th.diag_embed(sqrt_cov) | ||||
|         return mean, sqrt_cov | ||||
|     if mean.shape[0] != sqrt_cov.shape[0]: | ||||
|         shape = list(sqrt_cov.shape) | ||||
|         shape[0] = mean.shape[0] | ||||
|         shape = tuple(shape) | ||||
|         sqrt_cov = sqrt_cov.expand(shape) | ||||
|     if expand and len(sqrt_cov.shape) <= 2: | ||||
|         sqrt_cov = th.diag_embed(sqrt_cov) | ||||
|     return mean, sqrt_cov | ||||
| 
 | ||||
| 
 | ||||
| def get_cov(p: AnyDistribution): | ||||
| @ -123,12 +126,14 @@ def new_dist_like_from_sqrt(orig_p: AnyDistribution, mean: th.Tensor, cov_sqrt: | ||||
| 
 | ||||
| 
 | ||||
| def _sqrt_to_chol(cov_sqrt, only_diag=False): | ||||
|     if only_diag: | ||||
|         return cov_sqrt | ||||
|     cov = th.bmm(cov_sqrt.mT, cov_sqrt) | ||||
|     cov += th.eye(cov.shape[-1]).expand(cov.shape)*(1e-6) | ||||
| 
 | ||||
|     chol = th.linalg.cholesky(cov) | ||||
| 
 | ||||
|     if only_diag: | ||||
|         chol = th.diagonal(chol, dim1=-2, dim2=-1) | ||||
|     #if only_diag: | ||||
|     #    chol = th.diagonal(chol, dim1=-2, dim2=-1) | ||||
| 
 | ||||
|     return chol | ||||
|  | ||||
| @ -86,6 +86,7 @@ class BaseProjectionLayer(object): | ||||
|         Returns: | ||||
|             projected_dist, old_dist (from rollouts) | ||||
|         """ | ||||
| 
 | ||||
|         old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp) | ||||
|         return self(dist, old_distribution, **kwargs), old_distribution | ||||
| 
 | ||||
|  | ||||
| @ -4,7 +4,7 @@ from typing import Tuple | ||||
| from .base_projection_layer import BaseProjectionLayer, mean_projection | ||||
| 
 | ||||
| from ..misc.norm import mahalanobis, frob_sq | ||||
| from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like | ||||
| from ..misc.distTools import get_mean_and_chol, get_cov, new_dist_like, has_diag_cov | ||||
| 
 | ||||
| 
 | ||||
| class FrobeniusProjectionLayer(BaseProjectionLayer): | ||||
| @ -57,6 +57,9 @@ class FrobeniusProjectionLayer(BaseProjectionLayer): | ||||
|         else: | ||||
|             proj_chol = chol | ||||
| 
 | ||||
|         if has_diag_cov(p): | ||||
|             proj_chol = th.diagonal(proj_chol, dim1=-2, dim2=-1) | ||||
| 
 | ||||
|         proj_p = new_dist_like(p, proj_mean, proj_chol) | ||||
|         return proj_p | ||||
| 
 | ||||
|  | ||||
| @ -63,6 +63,9 @@ class WassersteinProjectionLayer(BaseProjectionLayer): | ||||
|         else: | ||||
|             proj_sqrt = sqrt | ||||
| 
 | ||||
|         if has_diag_cov(p): | ||||
|             proj_sqrt = th.diagonal(proj_sqrt, dim1=-2, dim2=-1) | ||||
| 
 | ||||
|         proj_p = new_dist_like_from_sqrt(p, proj_mean, proj_sqrt) | ||||
|         return proj_p | ||||
| 
 | ||||
| @ -110,6 +113,7 @@ class WassersteinProjectionLayer(BaseProjectionLayer): | ||||
|                 mean, scale_tril=cov_sqrt) | ||||
|         else: | ||||
|             raise Exception('Dist-Type not implemented (of sb3 dist)') | ||||
|         p_out.cov_sqrt = cov_sqrt | ||||
|         return p_out | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user