Changes to work with new version of metastable baselines 2
This commit is contained in:
		
							parent
							
								
									a5309e0fb8
								
							
						
					
					
						commit
						8a438e275f
					
				| @ -13,7 +13,7 @@ class BaseProjectionLayer(object): | |||||||
|                  mean_bound: float = 0.03, |                  mean_bound: float = 0.03, | ||||||
|                  cov_bound: float = 1e-3, |                  cov_bound: float = 1e-3, | ||||||
|                  trust_region_coeff: float = 1.0, |                  trust_region_coeff: float = 1.0, | ||||||
|                  scale_prec: bool = True, |                  scale_prec: bool = False, | ||||||
|                  do_entropy_proj: bool = False, |                  do_entropy_proj: bool = False, | ||||||
|                  entropy_eq: bool = False, |                  entropy_eq: bool = False, | ||||||
|                  entropy_first: bool = False, |                  entropy_first: bool = False, | ||||||
| @ -26,6 +26,10 @@ class BaseProjectionLayer(object): | |||||||
|         self.scale_prec = scale_prec |         self.scale_prec = scale_prec | ||||||
|         self.mean_eq = False |         self.mean_eq = False | ||||||
| 
 | 
 | ||||||
|  |         assert not entropy_eq, 'Sorry pal; thats actually not implemented yet.' | ||||||
|  |         assert not entropy_first, 'Sorry pal; thats actually not implemented yet.' | ||||||
|  |         assert not do_entropy_proj, 'Sorry pal; thats actually not implemented yet.' | ||||||
|  | 
 | ||||||
|         self.entropy_first = entropy_first |         self.entropy_first = entropy_first | ||||||
|         self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection |         self.entropy_proj = entropy_equality_projection if entropy_eq else entropy_inequality_projection | ||||||
| 
 | 
 | ||||||
| @ -63,6 +67,7 @@ class BaseProjectionLayer(object): | |||||||
| 
 | 
 | ||||||
|         #################################################################################################################### |         #################################################################################################################### | ||||||
|         # entropy projection in the end |         # entropy projection in the end | ||||||
|  | 
 | ||||||
|         if not self.do_entropy_proj or self.entropy_first: |         if not self.do_entropy_proj or self.entropy_first: | ||||||
|             return new_p |             return new_p | ||||||
| 
 | 
 | ||||||
| @ -81,7 +86,7 @@ class BaseProjectionLayer(object): | |||||||
|         Returns: |         Returns: | ||||||
|             projected_dist, old_dist (from rollouts) |             projected_dist, old_dist (from rollouts) | ||||||
|         """ |         """ | ||||||
|         old_distribution = self.new_dist_like(dist, rollout_data.means, rollout_data.cov_decomps) |         old_distribution = self.new_dist_like(dist, rollout_data.mean, rollout_data.cov_decomp) | ||||||
|         return self(dist, old_distribution, **kwargs), old_distribution |         return self(dist, old_distribution, **kwargs), old_distribution | ||||||
| 
 | 
 | ||||||
|     def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): |     def _trust_region_projection(self, p, q, eps: th.Tensor, eps_cov: th.Tensor, **kwargs): | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user