From 4532135812db0b3479bb509c2afa8ffacc716dfa Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 3 Sep 2022 11:59:16 +0200 Subject: [PATCH] Finalized factoring out projections --- metastable_baselines/misc/norm.py | 31 ------------------------------- metastable_baselines/ppo/ppo.py | 3 +-- test.py | 2 +- 3 files changed, 2 insertions(+), 34 deletions(-) delete mode 100644 metastable_baselines/misc/norm.py diff --git a/metastable_baselines/misc/norm.py b/metastable_baselines/misc/norm.py deleted file mode 100644 index 894451b..0000000 --- a/metastable_baselines/misc/norm.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch as th -from torch.distributions.multivariate_normal import _batch_mahalanobis - - -def mahalanobis_alt(u, v, std): - """ - Stolen from Fabian's Code (Public Version) - - """ - delta = u - v - return th.triangular_solve(delta, std, upper=False)[0].pow(2).sum([-2, -1]) - - -def mahalanobis(u, v, chol): - delta = u - v - return _batch_mahalanobis(chol, delta) - - -def frob_sq(diff, is_spd=False): - # If diff is spd, we can use a (probably) more performant algorithm - if is_spd: - return _frob_sq_spd(diff) - return th.norm(diff, p='fro', dim=tuple(range(1, diff.dim()))).pow(2) - - -def _frob_sq_spd(diff): - return _batch_trace(diff @ diff) - - -def _batch_trace(x): - return th.diagonal(x, dim1=-2, dim2=-1).sum(-1) diff --git a/metastable_baselines/ppo/ppo.py b/metastable_baselines/ppo/ppo.py index 9e106d4..b25845e 100644 --- a/metastable_baselines/ppo/ppo.py +++ b/metastable_baselines/ppo/ppo.py @@ -336,8 +336,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm): policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \ trust_region_loss + self.action_coef * action_loss - import pdb - pdb.set_trace() + pg_losses.append(policy_loss.item()) loss = policy_loss + self.vf_coef * value_loss diff --git a/test.py b/test.py index 174fad9..a3f2a95 100755 --- a/test.py +++ b/test.py @@ -28,7 +28,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru MlpPolicyPPO, env, # KLProjectionLayer(trust_region_coeff=0.01), - projection=KLProjectionLayer(trust_region_coeff=0.01), + projection=WassersteinProjectionLayer(trust_region_coeff=0.01), policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type': ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, verbose=0,