Fixed bugs for Hybrid[Diag=>Full]
This commit is contained in:
		
							parent
							
								
									ad584d70fd
								
							
						
					
					
						commit
						cb9ee4f302
					
				@ -288,6 +288,9 @@ class CholNet(nn.Module):
 | 
			
		||||
                self.param = nn.Parameter(
 | 
			
		||||
                    th.ones(self.action_dim), requires_grad=True)
 | 
			
		||||
            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
 | 
			
		||||
                if self.enforce_positive_type == EnforcePositiveType.NONE:
 | 
			
		||||
                    raise Exception(
 | 
			
		||||
                        'For Hybrid[Diag=>Full] enforce_positive_type has to be not NONE. Otherwise required SPD-contraint can not be ensured for cov.')
 | 
			
		||||
                self.stds = nn.Linear(latent_dim, self.action_dim)
 | 
			
		||||
                self.padder = th.nn.ZeroPad2d((0, 1, 1, 0))
 | 
			
		||||
                # TODO: Init Non-zero?
 | 
			
		||||
@ -332,12 +335,16 @@ class CholNet(nn.Module):
 | 
			
		||||
                return diag_chol
 | 
			
		||||
            elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL:
 | 
			
		||||
                # TODO: Maybe possible to improve speed and stability by making conversion from pearson correlation + stds to cov in cholesky-form.
 | 
			
		||||
                stds = self.stds(x)
 | 
			
		||||
                stds = self._ensure_positive_func(self.stds(x))
 | 
			
		||||
                smol = self._parameterize_full(self.params)
 | 
			
		||||
                big = self.padder(smol)
 | 
			
		||||
                pearson_cor_chol = big + th.eye(stds.shape[0])
 | 
			
		||||
                try:
 | 
			
		||||
                    pearson_cor_chol = big + th.eye(stds.shape[-1])
 | 
			
		||||
                except:
 | 
			
		||||
                    import pdb
 | 
			
		||||
                    pdb.set_trace()
 | 
			
		||||
                pearson_cor = pearson_cor_chol.T @ pearson_cor_chol
 | 
			
		||||
                cov = stds * pearson_cor * stds
 | 
			
		||||
                cov = stds.T * pearson_cor * stds
 | 
			
		||||
                chol = th.linalg.cholesky(cov)
 | 
			
		||||
                return chol
 | 
			
		||||
            elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL:
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user