Working on SDC
This commit is contained in:
		
							parent
							
								
									4c4b12ee0e
								
							
						
					
					
						commit
						e4440428f8
					
				| @ -28,31 +28,39 @@ class Strength(Enum): | |||||||
|     DIAG = 2 |     DIAG = 2 | ||||||
|     FULL = 3 |     FULL = 3 | ||||||
| 
 | 
 | ||||||
|     # def __init__(self, num): |  | ||||||
|     #    self.num = num |  | ||||||
| 
 |  | ||||||
|     # @property |  | ||||||
|     # def foo(self): |  | ||||||
|     #    return self.num |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class ParametrizationType(Enum): | class ParametrizationType(Enum): | ||||||
|  |     # Currently only Chol is implemented | ||||||
|     CHOL = 1 |     CHOL = 1 | ||||||
|     SPHERICAL_CHOL = 2 |     #SPHERICAL_CHOL = 2 | ||||||
|     GIVENS = 3 |     #GIVENS = 3 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class EnforcePositiveType(Enum): | class EnforcePositiveType(Enum): | ||||||
|     SOFTPLUS = 1 |     # TODO: Allow custom params for softplus? | ||||||
|     ABS = 2 |     SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20)) | ||||||
|     RELU = 3 |     ABS = (2, th.abs) | ||||||
|     SELU = 4 |     RELU = (3, nn.ReLU(inplace=False)) | ||||||
|     LOG = 5 |     LOG = (4, th.log) | ||||||
|  | 
 | ||||||
|  |     def __init__(self, value, func): | ||||||
|  |         self.value = value | ||||||
|  |         self._func = func | ||||||
|  | 
 | ||||||
|  |     def apply(self, x): | ||||||
|  |         return self._func(x) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ProbSquashingType(Enum): | class ProbSquashingType(Enum): | ||||||
|     NONE = 0 |     NONE = (0, nn.Identity()) | ||||||
|     TANH = 1 |     TANH = (1, th.tanh) | ||||||
|  | 
 | ||||||
|  |     def __init__(self, value, func): | ||||||
|  |         self.value = value | ||||||
|  |         self._func = func | ||||||
|  | 
 | ||||||
|  |     def apply(self, x): | ||||||
|  |         return self._func(x) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): | def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStrength=None, allowedPTs=None, allowedPSTs=None): | ||||||
| @ -118,10 +126,10 @@ class UniversalGaussianDistribution(SB3_Distribution): | |||||||
| 
 | 
 | ||||||
|         :param latent_dim: Dimension of the last layer of the policy (before the action layer) |         :param latent_dim: Dimension of the last layer of the policy (before the action layer) | ||||||
|         :param log_std_init: Initial value for the log standard deviation |         :param log_std_init: Initial value for the log standard deviation | ||||||
|         :return: We return two nn.Modules (mean, pseudo_chol). chol can return a díagonal vector when it would only be a diagonal matrix. |         :return: We return two nn.Modules (mean, chol). | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         # TODO: Rename pseudo_cov to pseudo_chol |         # TODO: Allow chol to be vector when only diagonal. | ||||||
| 
 | 
 | ||||||
|         mean_actions = nn.Linear(latent_dim, self.action_dim) |         mean_actions = nn.Linear(latent_dim, self.action_dim) | ||||||
| 
 | 
 | ||||||
| @ -139,33 +147,33 @@ class UniversalGaussianDistribution(SB3_Distribution): | |||||||
|                 # TODO: Off-axis init? |                 # TODO: Off-axis init? | ||||||
|                 pseudo_cov_par = nn.Parameter( |                 pseudo_cov_par = nn.Parameter( | ||||||
|                     th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True) |                     th.diag_embed(th.ones(self.action_dim) * log_std_init), requires_grad=True) | ||||||
|             pseudo_cov = FakeModule(pseudo_cov_par) |             chol = FakeModule(pseudo_cov_par) | ||||||
|         elif self.par_strength == self.cov_strength: |         elif self.par_strength == self.cov_strength: | ||||||
|             if self.par_strength == Strength.NONE: |             if self.par_strength == Strength.NONE: | ||||||
|                 pseudo_cov = FakeModule(th.ones(self.action_dim)) |                 chol = FakeModule(th.ones(self.action_dim)) | ||||||
|             elif self.par_strength == Strength.SCALAR: |             elif self.par_strength == Strength.SCALAR: | ||||||
|                 # TODO: Does it work like this? Test! |                 # TODO: Does it work like this? Test! | ||||||
|                 std = nn.Linear(latent_dim, 1) |                 std = nn.Linear(latent_dim, 1) | ||||||
|                 pseudo_cov = th.ones(self.action_dim) * std |                 chol = th.ones(self.action_dim) * std | ||||||
|             elif self.par_strength == Strength.DIAG: |             elif self.par_strength == Strength.DIAG: | ||||||
|                 pseudo_cov = nn.Linear(latent_dim, self.action_dim) |                 chol = nn.Linear(latent_dim, self.action_dim) | ||||||
|             elif self.par_strength == Strength.FULL: |             elif self.par_strength == Strength.FULL: | ||||||
|                 pseudo_cov = self._parameterize_full(latent_dim) |                 chol = self._parameterize_full(latent_dim) | ||||||
|         elif self.par_strength > self.cov_strength: |         elif self.par_strength > self.cov_strength: | ||||||
|             raise Exception( |             raise Exception( | ||||||
|                 'The parameterization can not be stronger than the actual covariance.') |                 'The parameterization can not be stronger than the actual covariance.') | ||||||
|         else: |         else: | ||||||
|             if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: |             if self.par_strength == Strength.SCALAR and self.cov_strength == Strength.DIAG: | ||||||
|                 pseudo_cov = self._parameterize_hybrid_from_scalar(latent_dim) |                 chol = self._parameterize_hybrid_from_scalar(latent_dim) | ||||||
|             elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: |             elif self.par_strength == Strength.DIAG and self.cov_strength == Strength.FULL: | ||||||
|                 pseudo_cov = self._parameterize_hybrid_from_diag(latent_dim) |                 chol = self._parameterize_hybrid_from_diag(latent_dim) | ||||||
|             elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: |             elif self.par_strength == Strength.SCALAR and self.cov_strength == Strength.FULL: | ||||||
|                 raise Exception( |                 raise Exception( | ||||||
|                     'That does not even make any sense...') |                     'That does not even make any sense...') | ||||||
|             else: |             else: | ||||||
|                 raise Exception("This Exception can't happen (I think)") |                 raise Exception("This Exception can't happen (I think)") | ||||||
| 
 | 
 | ||||||
|         return mean_actions, pseudo_cov |         return mean_actions, chol | ||||||
| 
 | 
 | ||||||
|     def _parameterize_full(self, latent_dim): |     def _parameterize_full(self, latent_dim): | ||||||
|         # TODO: Implement various techniques for full parameterization (forcing SPD) |         # TODO: Implement various techniques for full parameterization (forcing SPD) | ||||||
| @ -177,6 +185,13 @@ class UniversalGaussianDistribution(SB3_Distribution): | |||||||
|         raise Exception( |         raise Exception( | ||||||
|             'Programmer-was-to-lazy-to-implement-this-Exception') |             'Programmer-was-to-lazy-to-implement-this-Exception') | ||||||
| 
 | 
 | ||||||
|  |     def _ensure_positive_func(self, x): | ||||||
|  |         return self.enforce_positive_type.apply(x) | ||||||
|  | 
 | ||||||
|  |     def _ensure_diagonal_positive(self, pseudo_chol): | ||||||
|  |         pseudo_chol.tril(-1) + self._ensure_positive_func(pseudo_chol.diagonal(dim1=-2, | ||||||
|  |                                                                                dim2=-1)).diag_embed() + pseudo_chol.triu(1) | ||||||
|  | 
 | ||||||
|     def _parameterize_hybrid_from_scalar(self, latent_dim): |     def _parameterize_hybrid_from_scalar(self, latent_dim): | ||||||
|         factor = nn.Linear(latent_dim, 1) |         factor = nn.Linear(latent_dim, 1) | ||||||
|         par_cov = th.ones(self.action_dim) * \ |         par_cov = th.ones(self.action_dim) * \ | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user