From ab557a88560c37477d3f8ce2121012ae99ed99f8 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Fri, 15 Jul 2022 15:03:51 +0200
Subject: [PATCH] Make MultivariateNormal policies work (and port Normal to
 Independent)

---
 .../distributions/distributions.py          | 32 +++++++-----
 metastable_baselines/misc/distTools.py      | 16 ++++--
 metastable_baselines/misc/rollout_buffer.py | 10 ++++
 metastable_baselines/ppo/__init__.py        |  4 +-
 metastable_baselines/ppo/policies.py        |  6 ++-
 metastable_baselines/ppo/ppo.py             | 13 +++--
 replay.py                                   | 11 +----
 test.py                                     | 49 ++++++++++---------
 8 files changed, 84 insertions(+), 57 deletions(-)

diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py
index bace3df..a67014e 100644
--- a/metastable_baselines/distributions/distributions.py
+++ b/metastable_baselines/distributions/distributions.py
@@ -4,7 +4,7 @@ from enum import Enum
 import gym
 import torch as th
 from torch import nn
-from torch.distributions import Normal, MultivariateNormal
+from torch.distributions import Normal, Independent, MultivariateNormal
 from math import pi
 
 from stable_baselines3.common.preprocessing import get_action_dim
@@ -37,6 +37,7 @@ class Strength(Enum):
 
 
 class ParametrizationType(Enum):
+    NONE = 0
     CHOL = 1
     SPHERICAL_CHOL = 2
     # Not (yet?) implemented:
@@ -46,6 +47,7 @@ class ParametrizationType(Enum):
 
 class EnforcePositiveType(Enum):
     # TODO: Allow custom params for softplus?
+    NONE = (0, nn.Identity())
     SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
     ABS = (2, th.abs)
     RELU = (3, nn.ReLU(inplace=False))
@@ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
             # TODO: Implement
             continue
         if ps == Strength.NONE:
-            yield (ps, cs, None, None)
+            yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
         else:
             for ept in allowedEPTs:
                 if cs == Strength.FULL:
                     for pt in allowedPTs:
                         yield (ps, cs, ept, pt)
                 else:
-                    yield (ps, cs, ept, ProbSquashingType.NONE)
 
 
 def make_proba_distribution(
@@ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim: Dimension of the action space.
     """
 
-    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
+    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
         super(UniversalGaussianDistribution, self).__init__()
         self.action_dim = action_dim
         self.par_strength = neural_strength
@@ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
         if use_sde:
             raise Exception('SDE is not yet implemented')
 
+        assert (parameterization_type != ParametrizationType.NONE) == (
+            cov_strength == Strength.FULL), 'A ParametrizationType must be set iff cov_strength is Strength.FULL'
+
     def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
         p = self.distribution
-        if isinstance(p, th.distributions.Normal):
+        if isinstance(p, Independent):
             if p.stddev.shape != chol.shape:
                 chol = th.diagonal(chol, dim1=1, dim2=2)
-            np = th.distributions.Normal(mean, chol)
-        elif isinstance(p, th.distributions.MultivariateNormal):
-            np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
+            np = Independent(Normal(mean, chol), 1)
+        elif isinstance(p, MultivariateNormal):
+            np = MultivariateNormal(mean, scale_tril=chol)
         new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
-                                            parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
+                                            parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
         new.distribution = np
         return new
 
@@ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
 
         # TODO: latent_pi is for SDE, implement.
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]: - self.distribution = Normal(mean_actions, chol) + self.distribution = Independent(Normal(mean_actions, chol), 1) elif self.cov_strength in [Strength.FULL]: - self.distribution = MultivariateNormal(mean_actions, cholesky=chol) + self.distribution = MultivariateNormal( + mean_actions, scale_tril=chol) if self.distribution == None: raise Exception('Unable to create torch distribution') return self @@ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution): :return: """ log_prob = self.distribution.log_prob(actions) - return sum_independent_dims(log_prob) + return log_prob def entropy(self) -> th.Tensor: - return sum_independent_dims(self.distribution.entropy()) + return self.distribution.entropy() def sample(self) -> th.Tensor: # Reparametrization trick to pass gradients diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 6d3e5ba..2e2c0d0 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution def get_mean_and_chol(p: AnyDistribution, expand=False): - if isinstance(p, th.distributions.Normal): + if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): if expand: return p.mean, th.diag_embed(p.stddev) else: @@ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution): def get_cov(p: AnyDistribution): - if isinstance(p, th.distributions.Normal): + if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): return th.diag_embed(p.variance) elif isinstance(p, th.distributions.MultivariateNormal): return p.covariance_matrix @@ -45,7 +45,7 @@ def get_cov(p: AnyDistribution): def has_diag_cov(p: AnyDistribution, numerical_check=True): if isinstance(p, SB3_Distribution): return has_diag_cov(p.distribution, numerical_check=numerical_check) - if isinstance(p, th.distributions.Normal): + if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent): return True if not numerical_check: return False @@ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True): def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): if isinstance(orig_p, UniversalGaussianDistribution): - return orig_p.new_list_like_me(mean, chol) + return orig_p.new_dist_like_me(mean, chol) elif isinstance(orig_p, th.distributions.Normal): if orig_p.stddev.shape != chol.shape: chol = th.diagonal(chol, dim1=1, dim2=2) return th.distributions.Normal(mean, chol) + elif isinstance(orig_p, th.distributions.Independent): + if orig_p.stddev.shape != chol.shape: + chol = th.diagonal(chol, dim1=1, dim2=2) + return th.distributions.Independent(th.distributions.Normal(mean, chol), 1) elif isinstance(orig_p, th.distributions.MultivariateNormal): return th.distributions.MultivariateNormal(mean, scale_tril=chol) elif isinstance(orig_p, SB3_Distribution): @@ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor): if isinstance(p, th.distributions.Normal): p_out = orig_p.__class__(orig_p.action_dim) p_out.distribution = th.distributions.Normal(mean, chol) + elif isinstance(p, th.distributions.Independent): + p_out = orig_p.__class__(orig_p.action_dim) + p_out.distribution = th.distributions.Independent( + th.distributions.Normal(mean, chol), 1) elif isinstance(p, 
th.distributions.MultivariateNormal):
             p_out = orig_p.__class__(orig_p.action_dim)
             p_out.distribution = th.distributions.MultivariateNormal(
diff --git a/metastable_baselines/misc/rollout_buffer.py b/metastable_baselines/misc/rollout_buffer.py
index d53cd26..8bec087 100644
--- a/metastable_baselines/misc/rollout_buffer.py
+++ b/metastable_baselines/misc/rollout_buffer.py
@@ -1,4 +1,5 @@
 from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from more_itertools import distribute
 
 import numpy as np
 import torch as th
@@ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor
 
+from ..misc.distTools import get_mean_and_chol
+from ..distributions.distributions import Strength, UniversalGaussianDistribution
 
 # TRL requires the original mean and covariance from the policy when the datapoint was created.
 # GaussianRolloutBuffer extends the RolloutBuffer by these two fields
@@ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass():
     def _setup_model(self) -> None:
         super()._setup_model()
 
+        cov_shape = self.action_space.shape
+
+        if isinstance(self.policy.action_dist, UniversalGaussianDistribution):
+            if self.policy.action_dist.cov_strength == Strength.FULL:
+                cov_shape = cov_shape + cov_shape
+
         self.rollout_buffer = GaussianRolloutBuffer(
             self.n_steps,
             self.observation_space,
@@ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass():
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            cov_shape=cov_shape,
         )
 
     def collect_rollouts(
diff --git a/metastable_baselines/ppo/__init__.py b/metastable_baselines/ppo/__init__.py
index 8938d03..28bb6ed 100644
--- a/metastable_baselines/ppo/__init__.py
+++ b/metastable_baselines/ppo/__init__.py
@@ -1,2 +1,2 @@
-from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
-from ..trl_pg.trl_pg import TRL_PG
+from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from .ppo import PPO
diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py
index 11579f9..63a99c9 100644
--- a/metastable_baselines/ppo/policies.py
+++ b/metastable_baselines/ppo/policies.py
@@ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy):
         normalize_images: bool = True,
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        dist_kwargs: Optional[Dict[str, Any]] = None,
     ):
 
         if optimizer_kwargs is None:
@@ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy):
         self.normalize_images = normalize_images
         self.log_std_init = log_std_init
 
-        dist_kwargs = None
         # Keyword arguments for gSDE distribution
         if use_sde:
-            dist_kwargs = {
+            add_dist_kwargs = {
                 "full_std": full_std,
                 "squash_output": squash_output,
                 "use_expln": use_expln,
                 "learn_features": False,
             }
+            dist_kwargs = dict(dist_kwargs or {})
+            dist_kwargs.update(add_dist_kwargs)
 
         if sde_net_arch is not None:
             warnings.warn(
diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py
index f083bbf..dd86154 100644
--- a/metastable_baselines/ppo/ppo.py
+++ b/metastable_baselines/ppo/ppo.py
@@ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer
 from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
 
 
-class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
+class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
     """
     Differential Trust Region Layer (TRL) for Policy Gradient (PG)
 
@@ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
             q_dist = new_dist_like(
                 p_dist, rollout_data.means, rollout_data.stds)
             proj_p = self.projection(p_dist, q_dist, self._global_steps)
-            log_prob = proj_p.log_prob(actions).sum(dim=1)
+            if isinstance(p_dist, th.distributions.Normal):
+                # Normal keeps the action dims in batch_shape, so log_prob is per-dimension and must be summed
+                log_prob = proj_p.log_prob(actions).sum(dim=1)
+            else:
+                # UniversalGaussianDistribution uses Independent (or MultivariateNormal), whose log_prob is already reduced over the event dims
+                log_prob = proj_p.log_prob(actions)
             values = self.policy.value_net(latent_vf)
             entropy = proj_p.entropy()
 
@@ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         eval_env: Optional[GymEnv] = None,
         eval_freq: int = -1,
         n_eval_episodes: int = 5,
-        tb_log_name: str = "TRL_PG",
+        tb_log_name: str = "PPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> "TRL_PG":
+    ) -> "PPO":
 
         return super().learn(
             total_timesteps=total_timesteps,
diff --git a/replay.py b/replay.py
index db9c5fd..f4554fc 100755
--- a/replay.py
+++ b/replay.py
@@ -6,11 +6,10 @@ import os
 import time
 import datetime
 
-from stable_baselines3 import SAC, PPO, A2C
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
 
-from metastable_baselines.trl_pg import TRL_PG
+from metastable_baselines.ppo import PPO
 import columbus
 
 
@@ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0):
     use_sde = file_name.find('sde') != -1
     print(env_name, alg_name, alg_deriv, use_sde)
     env = gym.make(env_name)
-    if alg_name == 'ppo':
-        Model = PPO
-    elif alg_name == 'trl' and alg_deriv == 'pg':
-        Model = TRL_PG
-    else:
-        raise Exception('Algorithm not implemented for replay')
 
-    model = Model.load(load_path, env=env)
+    model = PPO.load(load_path, env=env)
 
     if n_eval_episodes:
         mean_reward, std_reward = evaluate_policy(
diff --git a/test.py b/test.py
index 085b73a..65b5b71 100755
--- a/test.py
+++ b/test.py
@@ -6,25 +6,28 @@ import os
 import time
 import datetime
 
-from stable_baselines3 import
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
 
 from metastable_baselines.ppo import PPO
+# from metastable_baselines.sac import SAC
 from metastable_baselines.ppo.policies import MlpPolicy
 from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
 import columbus
 
-#root_path = os.getcwd()
+from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
+
 root_path = '.'
-def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0): +def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0): env = gym.make(env_name) use_sde = False ppo = PPO( MlpPolicy, env, + policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type': + ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, verbose=0, tensorboard_log=root_path+"/logs_tb/" + env_name+"/ppo"+(['', '_sde'][use_sde])+"/", @@ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr use_sde=use_sde, # False clip_range=0.2, ) - trl_frob = PPO( - MlpPolicy, - env, - projection=FrobeniusProjectionLayer(), - verbose=0, - tensorboard_log=root_path+"/logs_tb/"+env_name + - "/trl_frob"+(['', '_sde'][use_sde])+"/", - learning_rate=3e-4, - gamma=0.99, - gae_lambda=0.95, - normalize_advantage=True, - ent_coef=0.03, # 0.1 - vf_coef=0.5, - use_sde=use_sde, - clip_range=2, # 0.2 - ) + # trl_frob = PPO( + # MlpPolicy, + # env, + # projection=FrobeniusProjectionLayer(), + # verbose=0, + # tensorboard_log=root_path+"/logs_tb/"+env_name + + # "/trl_frob"+(['', '_sde'][use_sde])+"/", + # learning_rate=3e-4, + # gamma=0.99, + # gae_lambda=0.95, + # normalize_advantage=True, + # ent_coef=0.03, # 0.1 + # vf_coef=0.5, + # use_sde=use_sde, + # clip_range=2, # 0.2 + # ) print('PPO:') testModel(ppo, timesteps, showRes, saveModel, n_eval_episodes) - print('TRL_frob:') - testModel(trl_frob, timesteps, showRes, - saveModel, n_eval_episodes) + # print('TRL_frob:') + # testModel(trl_frob, timesteps, showRes, + # saveModel, n_eval_episodes) def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16): @@ -95,7 +98,7 @@ def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes= env.render() episode_reward += reward if done: - #print("Reward:", episode_reward) + # print("Reward:", episode_reward) episode_reward = 0.0 obs = env.reset() env.reset()
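
Why Normal needed sum_independent_dims() but Independent and MultivariateNormal do not: torch keeps a Normal's action dimensions in batch_shape, while Independent(..., 1) and MultivariateNormal move them into event_shape, so log_prob() and entropy() come back already reduced to one value per sample. A minimal standalone sketch of this behaviour (illustration only, not part of the patch; the batch size 2 and action_dim 3 are assumed values):

    import torch as th
    from torch.distributions import Normal, Independent, MultivariateNormal

    mean = th.zeros(2, 3)    # batch of 2, action_dim of 3
    std = th.ones(2, 3)
    actions = th.zeros(2, 3)

    # Normal: action dims stay in batch_shape, so log_prob is per-dimension
    # and needs a manual .sum(dim=1) / sum_independent_dims() reduction.
    n = Normal(mean, std)
    assert n.log_prob(actions).shape == (2, 3)

    # Independent(..., 1): the last dim becomes event_shape, so log_prob
    # and entropy are already summed over the action dimensions.
    ind = Independent(Normal(mean, std), 1)
    assert ind.log_prob(actions).shape == (2,)
    assert ind.entropy().shape == (2,)

    # MultivariateNormal reduces the same way; scale_tril takes the
    # Cholesky factor, as in the patched proba_distribution().
    chol = th.eye(3).expand(2, 3, 3)
    mvn = MultivariateNormal(mean, scale_tril=chol)
    assert mvn.log_prob(actions).shape == (2,)

This is why log_prob() and entropy() in distributions.py drop sum_independent_dims(), and why ppo.py only applies .sum(dim=1) when the projected distribution is still a plain Normal.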