Making MultivariateNormal Policies work (and porting Normal to
Independent)
This commit is contained in:
		
							parent
							
								
									b1ed9fc2b8
								
							
						
					
					
						commit
						ab557a8856
					
				| @ -4,7 +4,7 @@ from enum import Enum | ||||
| import gym | ||||
| import torch as th | ||||
| from torch import nn | ||||
| from torch.distributions import Normal, MultivariateNormal | ||||
| from torch.distributions import Normal, Independent, MultivariateNormal | ||||
| from math import pi | ||||
| 
 | ||||
| from stable_baselines3.common.preprocessing import get_action_dim | ||||
| @ -37,6 +37,7 @@ class Strength(Enum): | ||||
| 
 | ||||
| 
 | ||||
| class ParametrizationType(Enum): | ||||
|     NONE = 0 | ||||
|     CHOL = 1 | ||||
|     SPHERICAL_CHOL = 2 | ||||
|     # Not (yet?) implemented: | ||||
| @ -46,6 +47,7 @@ class ParametrizationType(Enum): | ||||
| 
 | ||||
| class EnforcePositiveType(Enum): | ||||
|     # TODO: Allow custom params for softplus? | ||||
|     NONE = (0, nn.Identity()) | ||||
|     SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20)) | ||||
|     ABS = (2, th.abs) | ||||
|     RELU = (3, nn.ReLU(inplace=False)) | ||||
| @ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng | ||||
|                 # TODO: Implement | ||||
|                 continue | ||||
|             if ps == Strength.NONE: | ||||
|                 yield (ps, cs, None, None) | ||||
|                 yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE) | ||||
|             else: | ||||
|                 for ept in allowedEPTs: | ||||
|                     if cs == Strength.FULL: | ||||
|                         for pt in allowedPTs: | ||||
|                             yield (ps, cs, ept, pt) | ||||
|                     else: | ||||
|                         yield (ps, cs, ept, None) | ||||
|                         yield (ps, cs, ept, ProbSquashingType.NONE) | ||||
| 
 | ||||
| 
 | ||||
| def make_proba_distribution( | ||||
| @ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|     :param action_dim:  Dimension of the action space. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE): | ||||
|     def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE): | ||||
|         super(UniversalGaussianDistribution, self).__init__() | ||||
|         self.action_dim = action_dim | ||||
|         self.par_strength = neural_strength | ||||
| @ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|         if use_sde: | ||||
|             raise Exception('SDE is not yet implemented') | ||||
| 
 | ||||
|         assert (parameterization_type != ParametrizationType.NONE) == ( | ||||
|             cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full' | ||||
| 
 | ||||
|     def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor): | ||||
|         p = self.distribution | ||||
|         if isinstance(p, th.distributions.Normal): | ||||
|         if isinstance(p, Independent): | ||||
|             if p.stddev.shape != chol.shape: | ||||
|                 chol = th.diagonal(chol, dim1=1, dim2=2) | ||||
|             np = th.distributions.Normal(mean, chol) | ||||
|         elif isinstance(p, th.distributions.MultivariateNormal): | ||||
|             np = th.distributions.MultivariateNormal(mean, scale_tril=chol) | ||||
|             np = Independent(Normal(mean, chol), 1) | ||||
|         elif isinstance(p, MultivariateNormal): | ||||
|             np = MultivariateNormal(mean, scale_tril=chol) | ||||
|         new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength, | ||||
|                                             parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type) | ||||
|                                             parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type) | ||||
|         new.distribution = np | ||||
| 
 | ||||
|         return new | ||||
| @ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|         # TODO: latent_pi is for SDE, implement. | ||||
| 
 | ||||
|         if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]: | ||||
|             self.distribution = Normal(mean_actions, chol) | ||||
|             self.distribution = Independent(Normal(mean_actions, chol), 1) | ||||
|         elif self.cov_strength in [Strength.FULL]: | ||||
|             self.distribution = MultivariateNormal(mean_actions, cholesky=chol) | ||||
|             self.distribution = MultivariateNormal( | ||||
|                 mean_actions, scale_tril=chol) | ||||
|         if self.distribution == None: | ||||
|             raise Exception('Unable to create torch distribution') | ||||
|         return self | ||||
| @ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution): | ||||
|         :return: | ||||
|         """ | ||||
|         log_prob = self.distribution.log_prob(actions) | ||||
|         return sum_independent_dims(log_prob) | ||||
|         return log_prob | ||||
| 
 | ||||
|     def entropy(self) -> th.Tensor: | ||||
|         return sum_independent_dims(self.distribution.entropy()) | ||||
|         return self.distribution.entropy() | ||||
| 
 | ||||
|     def sample(self) -> th.Tensor: | ||||
|         # Reparametrization trick to pass gradients | ||||
|  | ||||
| @ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution | ||||
| 
 | ||||
| 
 | ||||
| def get_mean_and_chol(p: AnyDistribution, expand=False): | ||||
|     if isinstance(p, th.distributions.Normal): | ||||
|     if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): | ||||
|         if expand: | ||||
|             return p.mean, th.diag_embed(p.stddev) | ||||
|         else: | ||||
| @ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution): | ||||
| 
 | ||||
| 
 | ||||
| def get_cov(p: AnyDistribution): | ||||
|     if isinstance(p, th.distributions.Normal): | ||||
|     if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): | ||||
|         return th.diag_embed(p.variance) | ||||
|     elif isinstance(p, th.distributions.MultivariateNormal): | ||||
|         return p.covariance_matrix | ||||
| @ -45,7 +45,7 @@ def get_cov(p: AnyDistribution): | ||||
| def has_diag_cov(p: AnyDistribution, numerical_check=True): | ||||
|     if isinstance(p, SB3_Distribution): | ||||
|         return has_diag_cov(p.distribution, numerical_check=numerical_check) | ||||
|     if isinstance(p, th.distributions.Normal): | ||||
|     if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): | ||||
|         return True | ||||
|     if not numerical_check: | ||||
|         return False | ||||
| @ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True): | ||||
| 
 | ||||
| def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): | ||||
|     if isinstance(orig_p, UniversalGaussianDistribution): | ||||
|         return orig_p.new_list_like_me(mean, chol) | ||||
|         return orig_p.new_dist_like_me(mean, chol) | ||||
|     elif isinstance(orig_p, th.distributions.Normal): | ||||
|         if orig_p.stddev.shape != chol.shape: | ||||
|             chol = th.diagonal(chol, dim1=1, dim2=2) | ||||
|         return th.distributions.Normal(mean, chol) | ||||
|     elif isinstance(orig_p, th.distributions.Independent): | ||||
|         if orig_p.stddev.shape != chol.shape: | ||||
|             chol = th.diagonal(chol, dim1=1, dim2=2) | ||||
|         return th.distributions.Independent(th.distributions.Normal(mean, chol), 1) | ||||
|     elif isinstance(orig_p, th.distributions.MultivariateNormal): | ||||
|         return th.distributions.MultivariateNormal(mean, scale_tril=chol) | ||||
|     elif isinstance(orig_p, SB3_Distribution): | ||||
| @ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): | ||||
|         if isinstance(p, th.distributions.Normal): | ||||
|             p_out = orig_p.__class__(orig_p.action_dim) | ||||
|             p_out.distribution = th.distributions.Normal(mean, chol) | ||||
|         elif isinstance(p, th.distributions.Independent): | ||||
|             p_out = orig_p.__class__(orig_p.action_dim) | ||||
|             p_out.distribution = th.distributions.Independent( | ||||
|                 th.distributions.Normal(mean, chol), 1) | ||||
|         elif isinstance(p, th.distributions.MultivariateNormal): | ||||
|             p_out = orig_p.__class__(orig_p.action_dim) | ||||
|             p_out.distribution = th.distributions.MultivariateNormal( | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| from typing import Any, Dict, Optional, Type, Union, NamedTuple | ||||
| from more_itertools import distribute | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch as th | ||||
| @ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv | ||||
| from stable_baselines3.common.callbacks import BaseCallback | ||||
| from stable_baselines3.common.utils import obs_as_tensor | ||||
| 
 | ||||
| from ..misc.distTools import get_mean_and_chol | ||||
| from ..distributions.distributions import Strength, UniversalGaussianDistribution | ||||
| 
 | ||||
| # TRL requires the origina mean and covariance from the policy when the datapoint was created. | ||||
| # GaussianRolloutBuffer extends the RolloutBuffer by these two fields | ||||
| @ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass(): | ||||
|     def _setup_model(self) -> None: | ||||
|         super()._setup_model() | ||||
| 
 | ||||
|         cov_shape = self.action_space.shape | ||||
| 
 | ||||
|         if isinstance(self.policy.action_dist, UniversalGaussianDistribution): | ||||
|             if self.policy.action_dist.cov_strength == Strength.FULL: | ||||
|                 cov_shape = cov_shape + cov_shape | ||||
| 
 | ||||
|         self.rollout_buffer = GaussianRolloutBuffer( | ||||
|             self.n_steps, | ||||
|             self.observation_space, | ||||
| @ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass(): | ||||
|             gamma=self.gamma, | ||||
|             gae_lambda=self.gae_lambda, | ||||
|             n_envs=self.n_envs, | ||||
|             cov_shape=cov_shape, | ||||
|         ) | ||||
| 
 | ||||
|     def collect_rollouts( | ||||
|  | ||||
| @ -1,2 +1,2 @@ | ||||
| from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy | ||||
| from ..trl_pg.trl_pg import TRL_PG | ||||
| from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy | ||||
| from .ppo import PPO | ||||
|  | ||||
| @ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy): | ||||
|         normalize_images: bool = True, | ||||
|         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, | ||||
|         optimizer_kwargs: Optional[Dict[str, Any]] = None, | ||||
|         dist_kwargs: Optional[Dict[str, Any]] = None, | ||||
|     ): | ||||
| 
 | ||||
|         if optimizer_kwargs is None: | ||||
| @ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy): | ||||
| 
 | ||||
|         self.normalize_images = normalize_images | ||||
|         self.log_std_init = log_std_init | ||||
|         dist_kwargs = None | ||||
|         # Keyword arguments for gSDE distribution | ||||
|         if use_sde: | ||||
|             dist_kwargs = { | ||||
|             add_dist_kwargs = { | ||||
|                 "full_std": full_std, | ||||
|                 "squash_output": squash_output, | ||||
|                 "use_expln": use_expln, | ||||
|                 "learn_features": False, | ||||
|             } | ||||
|             for k in add_dist_kwargs: | ||||
|                 dist_kwargs[k] = add_dist_kwargs[k] | ||||
| 
 | ||||
|         if sde_net_arch is not None: | ||||
|             warnings.warn( | ||||
|  | ||||
| @ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer | ||||
| from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass | ||||
| 
 | ||||
| 
 | ||||
| class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): | ||||
| class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): | ||||
|     """ | ||||
|     Differential Trust Region Layer (TRL) for Policy Gradient (PG) | ||||
| 
 | ||||
| @ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): | ||||
|                 q_dist = new_dist_like( | ||||
|                     p_dist, rollout_data.means, rollout_data.stds) | ||||
|                 proj_p = self.projection(p_dist, q_dist, self._global_steps) | ||||
|                 if isinstance(p_dist, th.distributions.Normal): | ||||
|                     # Normal uses a weird mapping from dimensions into batch_shape | ||||
|                     log_prob = proj_p.log_prob(actions).sum(dim=1) | ||||
|                 else: | ||||
|                     # UniversalGaussianDistribution instead uses Independent (or MultivariateNormal), which has a more rational dim mapping | ||||
|                     log_prob = proj_p.log_prob(actions) | ||||
|                 values = self.policy.value_net(latent_vf) | ||||
|                 entropy = proj_p.entropy() | ||||
| 
 | ||||
| @ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): | ||||
|         eval_env: Optional[GymEnv] = None, | ||||
|         eval_freq: int = -1, | ||||
|         n_eval_episodes: int = 5, | ||||
|         tb_log_name: str = "TRL_PG", | ||||
|         tb_log_name: str = "PPO", | ||||
|         eval_log_path: Optional[str] = None, | ||||
|         reset_num_timesteps: bool = True, | ||||
|     ) -> "TRL_PG": | ||||
|     ) -> "PPO": | ||||
| 
 | ||||
|         return super().learn( | ||||
|             total_timesteps=total_timesteps, | ||||
|  | ||||
							
								
								
									
										11
									
								
								replay.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								replay.py
									
									
									
									
									
								
							| @ -6,11 +6,10 @@ import os | ||||
| import time | ||||
| import datetime | ||||
| 
 | ||||
| from stable_baselines3 import SAC, PPO, A2C | ||||
| from stable_baselines3.common.evaluation import evaluate_policy | ||||
| from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy | ||||
| 
 | ||||
| from metastable_baselines.trl_pg import TRL_PG | ||||
| from metastable_baselines.ppo import PPO | ||||
| import columbus | ||||
| 
 | ||||
| 
 | ||||
| @ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0): | ||||
|     use_sde = file_name.find('sde') != -1 | ||||
|     print(env_name, alg_name, alg_deriv, use_sde) | ||||
|     env = gym.make(env_name) | ||||
|     if alg_name == 'ppo': | ||||
|         Model = PPO | ||||
|     elif alg_name == 'trl' and alg_deriv == 'pg': | ||||
|         Model = TRL_PG | ||||
|     else: | ||||
|         raise Exception('Algorithm not implemented for replay') | ||||
| 
 | ||||
|     model = Model.load(load_path, env=env) | ||||
|     model = PPO.load(load_path, env=env) | ||||
| 
 | ||||
|     if n_eval_episodes: | ||||
|         mean_reward, std_reward = evaluate_policy( | ||||
|  | ||||
							
								
								
									
										47
									
								
								test.py
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								test.py
									
									
									
									
									
								
							| @ -6,25 +6,28 @@ import os | ||||
| import time | ||||
| import datetime | ||||
| 
 | ||||
| from stable_baselines3 import | ||||
| from stable_baselines3.common.evaluation import evaluate_policy | ||||
| from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy | ||||
| 
 | ||||
| from metastable_baselines.ppo import PPO | ||||
| # from metastable_baselines.sac import SAC | ||||
| from metastable_baselines.ppo.policies import MlpPolicy | ||||
| from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer | ||||
| import columbus | ||||
| 
 | ||||
| #root_path = os.getcwd() | ||||
| from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType | ||||
| 
 | ||||
| root_path = '.' | ||||
| 
 | ||||
| 
 | ||||
| def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0): | ||||
| def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0): | ||||
|     env = gym.make(env_name) | ||||
|     use_sde = False | ||||
|     ppo = PPO( | ||||
|         MlpPolicy, | ||||
|         env, | ||||
|         policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type': | ||||
|                        ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, | ||||
|         verbose=0, | ||||
|         tensorboard_log=root_path+"/logs_tb/" + | ||||
|         env_name+"/ppo"+(['', '_sde'][use_sde])+"/", | ||||
| @ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr | ||||
|         use_sde=use_sde,  # False | ||||
|         clip_range=0.2, | ||||
|     ) | ||||
|     trl_frob = PPO( | ||||
|         MlpPolicy, | ||||
|         env, | ||||
|         projection=FrobeniusProjectionLayer(), | ||||
|         verbose=0, | ||||
|         tensorboard_log=root_path+"/logs_tb/"+env_name + | ||||
|         "/trl_frob"+(['', '_sde'][use_sde])+"/", | ||||
|         learning_rate=3e-4, | ||||
|         gamma=0.99, | ||||
|         gae_lambda=0.95, | ||||
|         normalize_advantage=True, | ||||
|         ent_coef=0.03,  # 0.1 | ||||
|         vf_coef=0.5, | ||||
|         use_sde=use_sde, | ||||
|         clip_range=2,  # 0.2 | ||||
|     ) | ||||
|     # trl_frob = PPO( | ||||
|     #    MlpPolicy, | ||||
|     #    env, | ||||
|     #    projection=FrobeniusProjectionLayer(), | ||||
|     #    verbose=0, | ||||
|     #    tensorboard_log=root_path+"/logs_tb/"+env_name + | ||||
|     #    "/trl_frob"+(['', '_sde'][use_sde])+"/", | ||||
|     #    learning_rate=3e-4, | ||||
|     #    gamma=0.99, | ||||
|     #    gae_lambda=0.95, | ||||
|     #    normalize_advantage=True, | ||||
|     #    ent_coef=0.03,  # 0.1 | ||||
|     #    vf_coef=0.5, | ||||
|     #    use_sde=use_sde, | ||||
|     #    clip_range=2,  # 0.2 | ||||
|     # ) | ||||
| 
 | ||||
|     print('PPO:') | ||||
|     testModel(ppo, timesteps, showRes, | ||||
|               saveModel, n_eval_episodes) | ||||
|     print('TRL_frob:') | ||||
|     testModel(trl_frob, timesteps, showRes, | ||||
|               saveModel, n_eval_episodes) | ||||
|     # print('TRL_frob:') | ||||
|     # testModel(trl_frob, timesteps, showRes, | ||||
|     #          saveModel, n_eval_episodes) | ||||
| 
 | ||||
| 
 | ||||
| def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user