From d35c3d8520d75544138d9031fe16e8dd6a82392d Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Mon, 15 Aug 2022 16:55:17 +0200 Subject: [PATCH] Fixed all the bugs in TRPL --- metastable_baselines/distributions/distributions.py | 10 ++++++++-- metastable_baselines/misc/distTools.py | 7 ++++++- metastable_baselines/ppo/policies.py | 12 +++++++++++- test.py | 12 ++++++------ 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index 7521051..8f9ce4b 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -221,13 +221,19 @@ class UniversalGaussianDistribution(SB3_Distribution): def _sqrt_to_chol(self, cov_sqrt): vec = False - if len(cov_sqrt.shape) == 2: + nobatch = False + if len(cov_sqrt.shape) <= 2: vec = True + if len(cov_sqrt.shape) == 1: + nobatch = True if vec: cov_sqrt = th.diag_embed(cov_sqrt) - cov = th.bmm(cov_sqrt.mT, cov_sqrt) + if nobatch: + cov = th.mm(cov_sqrt.mT, cov_sqrt) + else: + cov = th.bmm(cov_sqrt.mT, cov_sqrt) chol = th.linalg.cholesky(cov) if vec: diff --git a/metastable_baselines/misc/distTools.py b/metastable_baselines/misc/distTools.py index 4983fc9..ef0eee0 100644 --- a/metastable_baselines/misc/distTools.py +++ b/metastable_baselines/misc/distTools.py @@ -26,7 +26,12 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution, expand=False): else: mean, chol = get_mean_and_chol(p, expand=False) sqrt_cov = p.cov_sqrt - if expand and len(sqrt_cov.shape) == 2: + if mean.shape[0] != sqrt_cov.shape[0]: + shape = list(sqrt_cov.shape) + shape[0] = mean.shape[0] + shape = tuple(shape) + sqrt_cov = sqrt_cov.expand(shape) + if expand and len(sqrt_cov.shape) <= 2: sqrt_cov = th.diag_embed(sqrt_cov) return mean, sqrt_cov diff --git a/metastable_baselines/ppo/policies.py b/metastable_baselines/ppo/policies.py index eb9d137..273aa79 100644 --- a/metastable_baselines/ppo/policies.py +++ b/metastable_baselines/ppo/policies.py @@ -315,9 +315,19 @@ class ActorCriticPolicy(BasePolicy): elif isinstance(self.action_dist, UniversalGaussianDistribution): if self.sqrt_induced_gaussian: chol_sqrt_cov = self.chol_net(latent_pi) - if len(chol_sqrt_cov.shape) == 2: + unembed = False + squeeze = False + if len(chol_sqrt_cov.shape) <= 2: + unembed = True chol_sqrt_cov = th.diag_embed(chol_sqrt_cov) + if len(chol_sqrt_cov.shape) <= 2: + squeeze = True + chol_sqrt_cov = chol_sqrt_cov.unsqueeze(0) cov_sqrt = th.bmm(chol_sqrt_cov.mT, chol_sqrt_cov) + if squeeze and False: + cov_sqrt = cov_sqrt.squeeze() + if unembed: + cov_sqrt = th.diagonal(cov_sqrt, dim1=-2, dim2=-1) dist = self.action_dist.proba_distribution_from_sqrt( mean_actions, cov_sqrt, latent_pi) mean, chol = get_mean_and_chol(dist, expand=False) diff --git a/test.py b/test.py index 2edce78..833402f 100755 --- a/test.py +++ b/test.py @@ -20,21 +20,21 @@ root_path = '.' def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0): env = gym.make(env_name) - use_sde = True + use_sde = False ppo = PPO( MlpPolicyPPO, env, - projection=BaseProjectionLayer(), - policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type': - ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, + projection=KLProjectionLayer(trust_region_coeff=0.01), + policy_kwargs={'dist_kwargs': {'neural_strength': Strength.SCALAR, 'cov_strength': Strength.DIAG, 'parameterization_type': + ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, verbose=0, tensorboard_log=root_path+"/logs_tb/" + env_name+"/ppo"+(['', '_sde'][use_sde])+"/", - learning_rate=3e-4, + learning_rate=3e-4, # 3e-4, gamma=0.99, gae_lambda=0.95, normalize_advantage=True, - ent_coef=0.02, # 0.1 + ent_coef=0.1, # 0.1 vf_coef=0.5, use_sde=use_sde, # False clip_range=1 # 0.2,